You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by vi...@apache.org on 2021/07/03 15:30:00 UTC

[spark] branch branch-3.2 updated: [SPARK-35940][SQL] Refactor EquivalentExpressions to make it more efficient

This is an automated email from the ASF dual-hosted git repository.

viirya pushed a commit to branch branch-3.2
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.2 by this push:
     new ec84982  [SPARK-35940][SQL] Refactor EquivalentExpressions to make it more efficient
ec84982 is described below

commit ec8498219196cab1fc6dcb559541883e9ef0e77a
Author: Wenchen Fan <we...@databricks.com>
AuthorDate: Sat Jul 3 08:28:44 2021 -0700

    [SPARK-35940][SQL] Refactor EquivalentExpressions to make it more efficient
    
    ### What changes were proposed in this pull request?
    
    This PR uses 2 ideas to make `EquivalentExpressions` more efficient:
    1. do not keep all the equivalent expressions; we only need a count
    2. track the "height" of common subexpressions to quickly do the child-parent sort and to filter out non-child expressions in `addCommonExprs`
    
    This PR also fixes several small bugs (exposed by the refactoring); please see the PR comments.
    
    ### Why are the changes needed?
    
    Code cleanup and a small performance improvement.
    
    ### Does this PR introduce _any_ user-facing change?
    
    no
    
    ### How was this patch tested?
    
    existing tests
    
    Closes #33142 from cloud-fan/codegen.
    
    Authored-by: Wenchen Fan <we...@databricks.com>
    Signed-off-by: Liang-Chi Hsieh <vi...@gmail.com>
    (cherry picked from commit e6ce220690937718bdd9046a9ed69bf2feb57595)
    Signed-off-by: Liang-Chi Hsieh <vi...@gmail.com>
---
 .../expressions/EquivalentExpressions.scala        | 177 +++++++++++----------
 .../sql/catalyst/expressions/Expression.scala      |   2 +-
 .../expressions/SubExprEvaluationRuntime.scala     |  17 +-
 .../expressions/codegen/CodeGenerator.scala        |  71 +++++----
 .../spark/sql/catalyst/planning/patterns.scala     |   4 +-
 .../catalyst/expressions/CodeGenerationSuite.scala |  28 ++--
 .../SubexpressionEliminationSuite.scala            | 152 ++++++++----------
 .../execution/aggregate/HashAggregateExec.scala    |   2 +-
 8 files changed, 218 insertions(+), 235 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala
index dd7193b..ef04e88 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala
@@ -29,20 +29,8 @@ import org.apache.spark.sql.catalyst.expressions.objects.LambdaVariable
  * considered equal if for the same input(s), the same result is produced.
  */
 class EquivalentExpressions {
-  /**
-   * Wrapper around an Expression that provides semantic equality.
-   */
-  case class Expr(e: Expression) {
-    override def equals(o: Any): Boolean = o match {
-      case other: Expr => e.semanticEquals(other.e)
-      case _ => false
-    }
-
-    override def hashCode: Int = e.semanticHash()
-  }
-
   // For each expression, the set of equivalent expressions.
-  private val equivalenceMap = mutable.HashMap.empty[Expr, mutable.ArrayBuffer[Expression]]
+  private val equivalenceMap = mutable.HashMap.empty[ExpressionEquals, ExpressionStats]
 
   /**
    * Adds each expression to this data structure, grouping them with existing equivalent
@@ -50,29 +38,20 @@ class EquivalentExpressions {
    * Returns true if there was already a matching expression.
    */
   def addExpr(expr: Expression): Boolean = {
-    if (expr.deterministic) {
-      val e: Expr = Expr(expr)
-      val f = equivalenceMap.get(e)
-      if (f.isDefined) {
-        f.get += expr
-        true
-      } else {
-        equivalenceMap.put(e, mutable.ArrayBuffer(expr))
-        false
-      }
-    } else {
-      false
-    }
+    addExprToMap(expr, equivalenceMap)
   }
 
-  private def addExprToSet(expr: Expression, set: mutable.Set[Expr]): Boolean = {
+  private def addExprToMap(
+      expr: Expression, map: mutable.HashMap[ExpressionEquals, ExpressionStats]): Boolean = {
     if (expr.deterministic) {
-      val e = Expr(expr)
-      if (set.contains(e)) {
-        true
-      } else {
-        set.add(e)
-        false
+      val wrapper = ExpressionEquals(expr)
+      map.get(wrapper) match {
+        case Some(stats) =>
+          stats.useCount += 1
+          true
+        case _ =>
+          map.put(wrapper, ExpressionStats(expr)())
+          false
       }
     } else {
       false
@@ -93,25 +72,33 @@ class EquivalentExpressions {
    */
   private def addCommonExprs(
       exprs: Seq[Expression],
-      addFunc: Expression => Boolean = addExpr): Unit = {
-    val exprSetForAll = mutable.Set[Expr]()
-    addExprTree(exprs.head, addExprToSet(_, exprSetForAll))
-
-    val candidateExprs = exprs.tail.foldLeft(exprSetForAll) { (exprSet, expr) =>
-      val otherExprSet = mutable.Set[Expr]()
-      addExprTree(expr, addExprToSet(_, otherExprSet))
-      exprSet.intersect(otherExprSet)
+      map: mutable.HashMap[ExpressionEquals, ExpressionStats]): Unit = {
+    assert(exprs.length > 1)
+    var localEquivalenceMap = mutable.HashMap.empty[ExpressionEquals, ExpressionStats]
+    addExprTree(exprs.head, localEquivalenceMap)
+
+    exprs.tail.foreach { expr =>
+      val otherLocalEquivalenceMap = mutable.HashMap.empty[ExpressionEquals, ExpressionStats]
+      addExprTree(expr, otherLocalEquivalenceMap)
+      localEquivalenceMap = localEquivalenceMap.filter { case (key, _) =>
+        otherLocalEquivalenceMap.contains(key)
+      }
     }
 
-    // Not all expressions in the set should be added. We should filter out the related
-    // children nodes.
-    val commonExprSet = candidateExprs.filter { candidateExpr =>
-      candidateExprs.forall { expr =>
-        expr == candidateExpr || expr.e.find(_.semanticEquals(candidateExpr.e)).isEmpty
+    localEquivalenceMap.foreach { case (commonExpr, state) =>
+      val possibleParents = localEquivalenceMap.filter { case (_, v) => v.height > state.height }
+      val notChild = possibleParents.forall { case (k, _) =>
+        k == commonExpr || k.e.find(_.semanticEquals(commonExpr.e)).isEmpty
+      }
+      if (notChild) {
+        // If the `commonExpr` already appears in the equivalence map, calling `addExprTree` will
+        // increase the `useCount` and mark it as a common subexpression. Otherwise, `addExprTree`
+        // will recursively add `commonExpr` and its descendant to the equivalence map, in case
+        // they also appear in other places. For example, `If(a + b > 1, a + b + c, a + b + c)`,
+        // `a + b` also appears in the condition and should be treated as common subexpression.
+        addExprTree(commonExpr.e, map)
       }
     }
-
-    commonExprSet.foreach(expr => addExprTree(expr.e, addFunc))
   }
 
   // There are some special expressions that we should not recurse into all of its children.
@@ -135,6 +122,7 @@ class EquivalentExpressions {
   // For some special expressions we cannot just recurse into all of its children, but we can
   // recursively add the common expressions shared between all of its children.
   private def commonChildrenToRecurse(expr: Expression): Seq[Seq[Expression]] = expr match {
+    case _: CodegenFallback => Nil
     case i: If => Seq(Seq(i.trueValue, i.falseValue))
     case c: CaseWhen =>
       // We look at subexpressions in conditions and values of `CaseWhen` separately. It is
@@ -142,7 +130,13 @@ class EquivalentExpressions {
       // if it is shared among conditions, but it doesn't need to be shared in values. Similarly,
       // a subexpression among values doesn't need to be in conditions because no matter which
       // condition is true, it will be evaluated.
-      val conditions = c.branches.tail.map(_._1)
+      val conditions = if (c.branches.length > 1) {
+        c.branches.map(_._1)
+      } else {
+        // If there is only one branch, the first condition is already covered by
+        // `childrenToRecurse` and we should exclude it here.
+        Nil
+      }
       // For an expression to be in all branch values of a CaseWhen statement, it must also be in
       // the elseValue.
       val values = if (c.elseValue.nonEmpty) {
@@ -150,8 +144,11 @@ class EquivalentExpressions {
       } else {
         Nil
       }
+
       Seq(conditions, values)
-    case c: Coalesce => Seq(c.children.tail)
+    // If there is only one child, the first child is already covered by
+    // `childrenToRecurse` and we should exclude it here.
+    case c: Coalesce if c.children.length > 1 => Seq(c.children)
     case _ => Nil
   }
 
@@ -161,7 +158,7 @@ class EquivalentExpressions {
    */
   def addExprTree(
       expr: Expression,
-      addFunc: Expression => Boolean = addExpr): Unit = {
+      map: mutable.HashMap[ExpressionEquals, ExpressionStats] = equivalenceMap): Unit = {
     val skip = expr.isInstanceOf[LeafExpression] ||
       // `LambdaVariable` is usually used as a loop variable, which can't be evaluated ahead of the
       // loop. So we can't evaluate sub-expressions containing `LambdaVariable` at the beginning.
@@ -170,27 +167,30 @@ class EquivalentExpressions {
       // can cause error like NPE.
       (expr.isInstanceOf[PlanExpression[_]] && TaskContext.get != null)
 
-    if (!skip && !addFunc(expr)) {
-      childrenToRecurse(expr).foreach(addExprTree(_, addFunc))
-      commonChildrenToRecurse(expr).filter(_.nonEmpty).foreach(addCommonExprs(_, addFunc))
+    if (!skip && !addExprToMap(expr, map)) {
+      childrenToRecurse(expr).foreach(addExprTree(_, map))
+      commonChildrenToRecurse(expr).filter(_.nonEmpty).foreach(addCommonExprs(_, map))
     }
   }
 
   /**
-   * Returns all of the expression trees that are equivalent to `e`. Returns
-   * an empty collection if there are none.
+   * Returns the state of the given expression in the `equivalenceMap`, or None if no
+   * semantically equivalent expression has been added.
    */
-  def getEquivalentExprs(e: Expression): Seq[Expression] = {
-    equivalenceMap.getOrElse(Expr(e), Seq.empty).toSeq
+  def getExprState(e: Expression): Option[ExpressionStats] = {
+    equivalenceMap.get(ExpressionEquals(e))
+  }
+
+  // Exposed for testing.
+  private[sql] def getAllExprStates(count: Int = 0): Seq[ExpressionStats] = {
+    equivalenceMap.values.filter(_.useCount > count).toSeq.sortBy(_.height)
   }
 
   /**
-   * Returns all the equivalent sets of expressions which appear more than given `repeatTimes`
-   * times.
+   * Returns the expressions used more than once, sorted by ascending height (children first).
    */
-  def getAllEquivalentExprs(repeatTimes: Int = 0): Seq[Seq[Expression]] = {
-    equivalenceMap.values.map(_.toSeq).filter(_.size > repeatTimes).toSeq
-      .sortBy(_.head)(new ExpressionContainmentOrdering)
+  def getCommonSubexpressions: Seq[Expression] = {
+    getAllExprStates(1).map(_.expr)
   }
 
   /**
@@ -198,37 +198,40 @@ class EquivalentExpressions {
    * equivalent expressions with cardinality 1.
    */
   def debugString(all: Boolean = false): String = {
-    val sb: mutable.StringBuilder = new StringBuilder()
+    val sb = new java.lang.StringBuilder()
     sb.append("Equivalent expressions:\n")
-    equivalenceMap.foreach { case (k, v) =>
-      if (all || v.length > 1) {
-        sb.append("  " + v.mkString(", ")).append("\n")
-      }
+    equivalenceMap.values.filter(stats => all || stats.useCount > 1).foreach { stats =>
+      sb.append("  ").append(s"${stats.expr}: useCount = ${stats.useCount}").append('\n')
     }
     sb.toString()
   }
 }
 
 /**
- * Orders `Expression` by parent/child relations. The child expression is smaller
- * than parent expression. If there is child-parent relationships among the subexpressions,
- * we want the child expressions come first than parent expressions, so we can replace
- * child expressions in parent expressions with subexpression evaluation. Note that
- * this is not for general expression ordering. For example, two irrelevant or semantically-equal
- * expressions will be considered as equal by this ordering. But for the usage here, the order of
- * irrelevant expressions does not matter.
+ * Wrapper around an Expression that provides semantic equality.
  */
-class ExpressionContainmentOrdering extends Ordering[Expression] {
-  override def compare(x: Expression, y: Expression): Int = {
-    if (x.find(_.semanticEquals(y)).isDefined) {
-      // `y` is child expression of `x`.
-      1
-    } else if (y.find(_.semanticEquals(x)).isDefined) {
-      // `x` is child expression of `y`.
-      -1
-    } else {
-      // Irrelevant or semantically-equal expressions
-      0
-    }
+case class ExpressionEquals(e: Expression) {
+  override def equals(o: Any): Boolean = o match {
+    case other: ExpressionEquals => e.semanticEquals(other.e)
+    case _ => false
+  }
+
+  override def hashCode: Int = e.semanticHash()
+}
+
+/**
+ * A wrapper in place of using Seq[Expression] to record a group of equivalent expressions.
+ *
+ * This saves a lot of memory when there are many expressions in the same equivalence group.
+ * Instead of appending to a mutable list/buffer of Expressions, just update the "flattened"
+ * useCount in this wrapper in-place.
+ */
+case class ExpressionStats(expr: Expression)(var useCount: Int = 1) {
+  // This is used to do a fast pre-check for child-parent relationship. For example, expr1 can
+  // only be a parent of expr2 if expr1.height is larger than expr2.height.
+  lazy val height = getHeight(expr)
+
+  private def getHeight(tree: Expression): Int = {
+    tree.children.map(getHeight).reduceOption(_ max _).getOrElse(0) + 1
   }
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index c39db35..221f5ae 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -136,7 +136,7 @@ abstract class Expression extends TreeNode[Expression] {
    * @return [[ExprCode]]
    */
   def genCode(ctx: CodegenContext): ExprCode = {
-    ctx.subExprEliminationExprs.get(this).map { subExprState =>
+    ctx.subExprEliminationExprs.get(ExpressionEquals(this)).map { subExprState =>
       // This expression is repeated which means that the code to evaluate it has already been added
       // as a function before. In that case, we just re-use it.
       ExprCode(
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SubExprEvaluationRuntime.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SubExprEvaluationRuntime.scala
index 7886b65..fcc8ee6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SubExprEvaluationRuntime.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SubExprEvaluationRuntime.scala
@@ -73,11 +73,11 @@ class SubExprEvaluationRuntime(cacheMaxEntries: Int) {
    */
   private def replaceWithProxy(
       expr: Expression,
+      equivalentExpressions: EquivalentExpressions,
       proxyMap: IdentityHashMap[Expression, ExpressionProxy]): Expression = {
-    if (proxyMap.containsKey(expr)) {
-      proxyMap.get(expr)
-    } else {
-      expr.mapChildren(replaceWithProxy(_, proxyMap))
+    equivalentExpressions.getExprState(expr) match {
+      case Some(stats) if proxyMap.containsKey(stats.expr) => proxyMap.get(stats.expr)
+      case _ => expr.mapChildren(replaceWithProxy(_, equivalentExpressions, proxyMap))
     }
   }
 
@@ -91,9 +91,8 @@ class SubExprEvaluationRuntime(cacheMaxEntries: Int) {
 
     val proxyMap = new IdentityHashMap[Expression, ExpressionProxy]
 
-    val commonExprs = equivalentExpressions.getAllEquivalentExprs(1)
-    commonExprs.foreach { e =>
-      val expr = e.head
+    val commonExprs = equivalentExpressions.getCommonSubexpressions
+    commonExprs.foreach { expr =>
       val proxy = ExpressionProxy(expr, proxyExpressionCurrentId, this)
       proxyExpressionCurrentId += 1
 
@@ -102,12 +101,12 @@ class SubExprEvaluationRuntime(cacheMaxEntries: Int) {
       // common expr2, ..., common expr n), we will insert into `proxyMap` some key/value
       // pairs like Map(common expr 1 -> proxy(common expr 1), ...,
       // common expr n -> proxy(common expr 1)).
-      e.map(proxyMap.put(_, proxy))
+      proxyMap.put(expr, proxy)
     }
 
     // Only adding proxy if we find subexpressions.
     if (!proxyMap.isEmpty) {
-      expressions.map(replaceWithProxy(_, proxyMap))
+      expressions.map(replaceWithProxy(_, equivalentExpressions, proxyMap))
     } else {
       expressions
     }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 3a5c46c..43ac2ab 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -83,9 +83,7 @@ object ExprCode {
  *                 particular subexpressions, instead of all at once. In the case, we need
  *                 to make sure we evaluate all children subexpressions too.
  */
-case class SubExprEliminationState(
-  eval: ExprCode,
-  children: Seq[SubExprEliminationState])
+case class SubExprEliminationState(eval: ExprCode, children: Seq[SubExprEliminationState])
 
 object SubExprEliminationState {
   def apply(eval: ExprCode): SubExprEliminationState = {
@@ -108,8 +106,8 @@ object SubExprEliminationState {
  *                              calling common subexpressions.
  */
 case class SubExprCodes(
-  states: Map[Expression, SubExprEliminationState],
-  exprCodesNeedEvaluate: Seq[ExprCode])
+    states: Map[ExpressionEquals, SubExprEliminationState],
+    exprCodesNeedEvaluate: Seq[ExprCode])
 
 /**
  * The main information about a new added function.
@@ -426,7 +424,8 @@ class CodegenContext extends Logging {
 
   // Foreach expression that is participating in subexpression elimination, the state to use.
   // Visible for testing.
-  private[expressions] var subExprEliminationExprs = Map.empty[Expression, SubExprEliminationState]
+  private[expressions] var subExprEliminationExprs =
+    Map.empty[ExpressionEquals, SubExprEliminationState]
 
   // The collection of sub-expression result resetting methods that need to be called on each row.
   private val subexprFunctions = mutable.ArrayBuffer.empty[String]
@@ -1031,7 +1030,7 @@ class CodegenContext extends Logging {
    * expressions and common expressions, instead of using the mapping in current context.
    */
   def withSubExprEliminationExprs(
-      newSubExprEliminationExprs: Map[Expression, SubExprEliminationState])(
+      newSubExprEliminationExprs: Map[ExpressionEquals, SubExprEliminationState])(
       f: => Seq[ExprCode]): Seq[ExprCode] = {
     val oldsubExprEliminationExprs = subExprEliminationExprs
     subExprEliminationExprs = newSubExprEliminationExprs
@@ -1098,29 +1097,30 @@ class CodegenContext extends Logging {
     // Create a clear EquivalentExpressions and SubExprEliminationState mapping
     val equivalentExpressions: EquivalentExpressions = new EquivalentExpressions
     val localSubExprEliminationExprsForNonSplit =
-      mutable.HashMap.empty[Expression, SubExprEliminationState]
+      mutable.HashMap.empty[ExpressionEquals, SubExprEliminationState]
 
     // Add each expression tree and compute the common subexpressions.
     expressions.foreach(equivalentExpressions.addExprTree(_))
 
     // Get all the expressions that appear at least twice and set up the state for subexpression
     // elimination.
-    val commonExprs = equivalentExpressions.getAllEquivalentExprs(1)
+    val commonExprs = equivalentExpressions.getCommonSubexpressions
 
     val nonSplitCode = {
       val allStates = mutable.ArrayBuffer.empty[SubExprEliminationState]
-      commonExprs.map { exprs =>
+      commonExprs.map { expr =>
         withSubExprEliminationExprs(localSubExprEliminationExprsForNonSplit.toMap) {
-          val eval = exprs.head.genCode(this)
+          val eval = expr.genCode(this)
           // Collects other subexpressions from the children.
           val childrenSubExprs = mutable.ArrayBuffer.empty[SubExprEliminationState]
-          exprs.head.foreach {
-            case e if subExprEliminationExprs.contains(e) =>
-              childrenSubExprs += subExprEliminationExprs(e)
-            case _ =>
+          expr.foreach { e =>
+            subExprEliminationExprs.get(ExpressionEquals(e)) match {
+              case Some(state) => childrenSubExprs += state
+              case _ =>
+            }
           }
           val state = SubExprEliminationState(eval, childrenSubExprs.toSeq)
-          exprs.foreach(localSubExprEliminationExprsForNonSplit.put(_, state))
+          localSubExprEliminationExprsForNonSplit.put(ExpressionEquals(expr), state)
           allStates += state
           Seq(eval)
         }
@@ -1133,7 +1133,7 @@ class CodegenContext extends Logging {
     // evaluate the outputs used more than twice. So we need to extract these variables used by
     // subexpressions and evaluate them before subexpressions.
     val (inputVarsForAllFuncs, exprCodesNeedEvaluate) = commonExprs.map { expr =>
-      val (inputVars, exprCodes) = getLocalInputVariableValues(this, expr.head)
+      val (inputVars, exprCodes) = getLocalInputVariableValues(this, expr)
       (inputVars.toSeq, exprCodes.toSeq)
     }.unzip
 
@@ -1141,10 +1141,9 @@ class CodegenContext extends Logging {
     val (subExprsMap, exprCodes) = if (needSplit) {
       if (inputVarsForAllFuncs.map(calculateParamLengthFromExprValues).forall(isValidParamLength)) {
         val localSubExprEliminationExprs =
-          mutable.HashMap.empty[Expression, SubExprEliminationState]
+          mutable.HashMap.empty[ExpressionEquals, SubExprEliminationState]
 
-        commonExprs.zipWithIndex.foreach { case (exprs, i) =>
-          val expr = exprs.head
+        commonExprs.zipWithIndex.foreach { case (expr, i) =>
           val eval = withSubExprEliminationExprs(localSubExprEliminationExprs.toMap) {
             Seq(expr.genCode(this))
           }.head
@@ -1178,10 +1177,11 @@ class CodegenContext extends Logging {
 
           // Collects other subexpressions from the children.
           val childrenSubExprs = mutable.ArrayBuffer.empty[SubExprEliminationState]
-          exprs.head.foreach {
-            case e if localSubExprEliminationExprs.contains(e) =>
-              childrenSubExprs += localSubExprEliminationExprs(e)
-            case _ =>
+          expr.foreach { e =>
+            localSubExprEliminationExprs.get(ExpressionEquals(e)) match {
+              case Some(state) => childrenSubExprs += state
+              case _ =>
+            }
           }
 
           val inputVariables = inputVars.map(_.variableName).mkString(", ")
@@ -1189,7 +1189,7 @@ class CodegenContext extends Logging {
           val state = SubExprEliminationState(
             ExprCode(code, isNull, JavaCode.global(value, expr.dataType)),
             childrenSubExprs.toSeq)
-          exprs.foreach(localSubExprEliminationExprs.put(_, state))
+          localSubExprEliminationExprs.put(ExpressionEquals(expr), state)
         }
         (localSubExprEliminationExprs, exprCodesNeedEvaluate)
       } else {
@@ -1217,9 +1217,8 @@ class CodegenContext extends Logging {
 
     // Get all the expressions that appear at least twice and set up the state for subexpression
     // elimination.
-    val commonExprs = equivalentExpressions.getAllEquivalentExprs(1)
-    commonExprs.foreach { e =>
-      val expr = e.head
+    val commonExprs = equivalentExpressions.getCommonSubexpressions
+    commonExprs.foreach { expr =>
       val fnName = freshName("subExpr")
       val isNull = addMutableState(JAVA_BOOLEAN, "subExprIsNull")
       val value = addMutableState(javaType(expr.dataType), "subExprValue")
@@ -1255,7 +1254,7 @@ class CodegenContext extends Logging {
         ExprCode(code"$subExprCode",
           JavaCode.isNullGlobal(isNull),
           JavaCode.global(value, expr.dataType)))
-      subExprEliminationExprs ++= e.map(_ -> state).toMap
+      subExprEliminationExprs += ExpressionEquals(expr) -> state
     }
   }
 
@@ -1834,7 +1833,7 @@ object CodeGenerator extends Logging {
   def getLocalInputVariableValues(
       ctx: CodegenContext,
       expr: Expression,
-      subExprs: Map[Expression, SubExprEliminationState] = Map.empty)
+      subExprs: Map[ExpressionEquals, SubExprEliminationState] = Map.empty)
       : (Set[VariableValue], Set[ExprCode]) = {
     val argSet = mutable.Set[VariableValue]()
     val exprCodesNeedEvaluate = mutable.Set[ExprCode]()
@@ -1852,10 +1851,6 @@ object CodeGenerator extends Logging {
     val stack = mutable.Stack[Expression](expr)
     while (stack.nonEmpty) {
       stack.pop() match {
-        case e if subExprs.contains(e) =>
-          collectLocalVariable(subExprs(e).eval.value)
-          collectLocalVariable(subExprs(e).eval.isNull)
-
         case ref: BoundReference if ctx.currentVars != null &&
             ctx.currentVars(ref.ordinal) != null =>
           val exprCode = ctx.currentVars(ref.ordinal)
@@ -1868,7 +1863,13 @@ object CodeGenerator extends Logging {
           collectLocalVariable(exprCode.isNull)
 
         case e =>
-          stack.pushAll(e.children)
+          subExprs.get(ExpressionEquals(e)) match {
+            case Some(state) =>
+              collectLocalVariable(state.eval.value)
+              collectLocalVariable(state.eval.isNull)
+            case None =>
+              stack.pushAll(e.children)
+          }
       }
     }
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
index fffb482..37c7229 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
@@ -325,11 +325,11 @@ object PhysicalAggregation {
           case ae: AggregateExpression =>
             // The final aggregation buffer's attributes will be `finalAggregationAttributes`,
             // so replace each aggregate expression by its corresponding attribute in the set:
-            equivalentAggregateExpressions.getEquivalentExprs(ae).headOption
+            equivalentAggregateExpressions.getExprState(ae).map(_.expr)
               .getOrElse(ae).asInstanceOf[AggregateExpression].resultAttribute
             // Similar to AggregateExpression
           case ue: PythonUDF if PythonUDF.isGroupedAggPandasUDF(ue) =>
-            equivalentAggregateExpressions.getEquivalentExprs(ue).headOption
+            equivalentAggregateExpressions.getExprState(ue).map(_.expr)
               .getOrElse(ue).asInstanceOf[PythonUDF].resultAttribute
           case expression if !expression.foldable =>
             // Since we're using `namedGroupingAttributes` to extract the grouping key
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
index b100554..07e045c 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
@@ -457,6 +457,8 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper {
       Seq.range(0, 100).map(x => Literal(x.toLong))) == 201)
   }
 
+  private def wrap(expr: Expression): ExpressionEquals = ExpressionEquals(expr)
+
   test("SPARK-23760: CodegenContext.withSubExprEliminationExprs should save/restore correctly") {
 
     val ref = BoundReference(0, IntegerType, true)
@@ -472,19 +474,19 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper {
       val ctx = new CodegenContext
       val e = ref.genCode(ctx)
       // before
-      ctx.subExprEliminationExprs += ref -> SubExprEliminationState(
+      ctx.subExprEliminationExprs += wrap(ref) -> SubExprEliminationState(
         ExprCode(EmptyBlock, e.isNull, e.value))
-      assert(ctx.subExprEliminationExprs.contains(ref))
+      assert(ctx.subExprEliminationExprs.contains(wrap(ref)))
       // call withSubExprEliminationExprs
-      ctx.withSubExprEliminationExprs(Map(add1 -> dummy)) {
-        assert(ctx.subExprEliminationExprs.contains(add1))
-        assert(!ctx.subExprEliminationExprs.contains(ref))
+      ctx.withSubExprEliminationExprs(Map(wrap(add1) -> dummy)) {
+        assert(ctx.subExprEliminationExprs.contains(wrap(add1)))
+        assert(!ctx.subExprEliminationExprs.contains(wrap(ref)))
         Seq.empty
       }
       // after
       assert(ctx.subExprEliminationExprs.nonEmpty)
-      assert(ctx.subExprEliminationExprs.contains(ref))
-      assert(!ctx.subExprEliminationExprs.contains(add1))
+      assert(ctx.subExprEliminationExprs.contains(wrap(ref)))
+      assert(!ctx.subExprEliminationExprs.contains(wrap(add1)))
     }
 
     // emulate an actual codegen workload
@@ -492,17 +494,17 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper {
       val ctx = new CodegenContext
       // before
       ctx.generateExpressions(Seq(add2, add1), doSubexpressionElimination = true) // trigger CSE
-      assert(ctx.subExprEliminationExprs.contains(add1))
+      assert(ctx.subExprEliminationExprs.contains(wrap(add1)))
       // call withSubExprEliminationExprs
-      ctx.withSubExprEliminationExprs(Map(ref -> dummy)) {
-        assert(ctx.subExprEliminationExprs.contains(ref))
-        assert(!ctx.subExprEliminationExprs.contains(add1))
+      ctx.withSubExprEliminationExprs(Map(wrap(ref) -> dummy)) {
+        assert(ctx.subExprEliminationExprs.contains(wrap(ref)))
+        assert(!ctx.subExprEliminationExprs.contains(wrap(add1)))
         Seq.empty
       }
       // after
       assert(ctx.subExprEliminationExprs.nonEmpty)
-      assert(ctx.subExprEliminationExprs.contains(add1))
-      assert(!ctx.subExprEliminationExprs.contains(ref))
+      assert(ctx.subExprEliminationExprs.contains(wrap(add1)))
+      assert(!ctx.subExprEliminationExprs.contains(wrap(ref)))
     }
   }
 
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala
index 0c65737..6fc9d04 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala
@@ -47,35 +47,32 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel
 
   test("Expression Equivalence - basic") {
     val equivalence = new EquivalentExpressions
-    assert(equivalence.getAllEquivalentExprs().isEmpty)
+    assert(equivalence.getAllExprStates().isEmpty)
 
     val oneA = Literal(1)
     val oneB = Literal(1)
     val twoA = Literal(2)
     var twoB = Literal(2)
 
-    assert(equivalence.getEquivalentExprs(oneA).isEmpty)
-    assert(equivalence.getEquivalentExprs(twoA).isEmpty)
+    assert(equivalence.getExprState(oneA).isEmpty)
+    assert(equivalence.getExprState(twoA).isEmpty)
 
     // Add oneA and test if it is returned. Since it is a group of one, it does not.
     assert(!equivalence.addExpr(oneA))
-    assert(equivalence.getEquivalentExprs(oneA).size == 1)
-    assert(equivalence.getEquivalentExprs(twoA).isEmpty)
-    assert(equivalence.addExpr((oneA)))
-    assert(equivalence.getEquivalentExprs(oneA).size == 2)
+    assert(equivalence.getExprState(oneA).get.useCount == 1)
+    assert(equivalence.getExprState(twoA).isEmpty)
+    assert(equivalence.addExpr(oneA))
+    assert(equivalence.getExprState(oneA).get.useCount == 2)
 
     // Add B and make sure they can see each other.
     assert(equivalence.addExpr(oneB))
     // Use exists and reference equality because of how equals is defined.
-    assert(equivalence.getEquivalentExprs(oneA).exists(_ eq oneB))
-    assert(equivalence.getEquivalentExprs(oneA).exists(_ eq oneA))
-    assert(equivalence.getEquivalentExprs(oneB).exists(_ eq oneA))
-    assert(equivalence.getEquivalentExprs(oneB).exists(_ eq oneB))
-    assert(equivalence.getEquivalentExprs(twoA).isEmpty)
-    assert(equivalence.getAllEquivalentExprs().size == 1)
-    assert(equivalence.getAllEquivalentExprs().head.size == 3)
-    assert(equivalence.getAllEquivalentExprs().head.contains(oneA))
-    assert(equivalence.getAllEquivalentExprs().head.contains(oneB))
+    assert(equivalence.getExprState(oneA).exists(_.expr eq oneA))
+    assert(equivalence.getExprState(oneB).exists(_.expr eq oneA))
+    assert(equivalence.getExprState(twoA).isEmpty)
+    assert(equivalence.getAllExprStates().size == 1)
+    assert(equivalence.getAllExprStates().head.useCount == 3)
+    assert(equivalence.getAllExprStates().head.expr eq oneA)
 
     val add1 = Add(oneA, oneB)
     val add2 = Add(oneA, oneB)
@@ -83,10 +80,10 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel
     equivalence.addExpr(add1)
     equivalence.addExpr(add2)
 
-    assert(equivalence.getAllEquivalentExprs().size == 2)
-    assert(equivalence.getEquivalentExprs(add2).exists(_ eq add1))
-    assert(equivalence.getEquivalentExprs(add2).size == 2)
-    assert(equivalence.getEquivalentExprs(add1).exists(_ eq add2))
+    assert(equivalence.getAllExprStates().size == 2)
+    assert(equivalence.getExprState(add1).exists(_.expr eq add1))
+    assert(equivalence.getExprState(add2).get.useCount == 2)
+    assert(equivalence.getExprState(add2).exists(_.expr eq add1))
   }
 
   test("Expression Equivalence - Trees") {
@@ -103,8 +100,8 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel
     equivalence.addExprTree(add2)
 
     // Should only have one equivalence for `one + two`
-    assert(equivalence.getAllEquivalentExprs(1).size == 1)
-    assert(equivalence.getAllEquivalentExprs(1).head.size == 4)
+    assert(equivalence.getAllExprStates(1).size == 1)
+    assert(equivalence.getAllExprStates(1).head.useCount == 4)
 
     // Set up the expressions
     //   one * two,
@@ -122,11 +119,11 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel
     equivalence.addExprTree(sum)
 
     // (one * two), (one * two) * (one * two) and sqrt( (one * two) * (one * two) ) should be found
-    assert(equivalence.getAllEquivalentExprs(1).size == 3)
-    assert(equivalence.getEquivalentExprs(mul).size == 3)
-    assert(equivalence.getEquivalentExprs(mul2).size == 3)
-    assert(equivalence.getEquivalentExprs(sqrt).size == 2)
-    assert(equivalence.getEquivalentExprs(sum).size == 1)
+    assert(equivalence.getAllExprStates(1).size == 3)
+    assert(equivalence.getExprState(mul).get.useCount == 3)
+    assert(equivalence.getExprState(mul2).get.useCount == 3)
+    assert(equivalence.getExprState(sqrt).get.useCount == 2)
+    assert(equivalence.getExprState(sum).get.useCount == 1)
   }
 
   test("Expression equivalence - non deterministic") {
@@ -134,7 +131,7 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel
     val equivalence = new EquivalentExpressions
     equivalence.addExpr(sum)
     equivalence.addExpr(sum)
-    assert(equivalence.getAllEquivalentExprs().isEmpty)
+    assert(equivalence.getAllExprStates().isEmpty)
   }
 
   test("Children of CodegenFallback") {
@@ -146,8 +143,8 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel
     val equivalence = new EquivalentExpressions
     equivalence.addExprTree(add)
     // the `two` inside `fallback` should not be added
-    assert(equivalence.getAllEquivalentExprs(1).size == 0)
-    assert(equivalence.getAllEquivalentExprs().count(_.size == 1) == 3) // add, two, explode
+    assert(equivalence.getAllExprStates(1).size == 0)
+    assert(equivalence.getAllExprStates().count(_.useCount == 1) == 3) // add, two, explode
   }
 
   test("Children of conditional expressions: If") {
@@ -159,35 +156,34 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel
     equivalence1.addExprTree(ifExpr1)
 
     // `add` is in both two branches of `If` and predicate.
-    assert(equivalence1.getAllEquivalentExprs().count(_.size == 2) == 1)
-    assert(equivalence1.getAllEquivalentExprs().filter(_.size == 2).head == Seq(add, add))
+    assert(equivalence1.getAllExprStates().count(_.useCount == 2) == 1)
+    assert(equivalence1.getAllExprStates().filter(_.useCount == 2).head.expr eq add)
     // one-time expressions: only ifExpr and its predicate expression
-    assert(equivalence1.getAllEquivalentExprs().count(_.size == 1) == 2)
-    assert(equivalence1.getAllEquivalentExprs().filter(_.size == 1).contains(Seq(ifExpr1)))
-    assert(equivalence1.getAllEquivalentExprs().filter(_.size == 1).contains(Seq(condition)))
+    assert(equivalence1.getAllExprStates().count(_.useCount == 1) == 2)
+    assert(equivalence1.getAllExprStates().filter(_.useCount == 1).exists(_.expr eq ifExpr1))
+    assert(equivalence1.getAllExprStates().filter(_.useCount == 1).exists(_.expr eq condition))
 
     // Repeated `add` is only in one branch, so we don't count it.
     val ifExpr2 = If(condition, Add(Literal(1), Literal(3)), Add(add, add))
     val equivalence2 = new EquivalentExpressions
     equivalence2.addExprTree(ifExpr2)
 
-    assert(equivalence2.getAllEquivalentExprs(1).size == 0)
-    assert(equivalence2.getAllEquivalentExprs().count(_.size == 1) == 3)
+    assert(equivalence2.getAllExprStates(1).isEmpty)
+    assert(equivalence2.getAllExprStates().count(_.useCount == 1) == 3)
 
     val ifExpr3 = If(condition, ifExpr1, ifExpr1)
     val equivalence3 = new EquivalentExpressions
     equivalence3.addExprTree(ifExpr3)
 
     // `add`: 2, `condition`: 2
-    assert(equivalence3.getAllEquivalentExprs().count(_.size == 2) == 2)
-    assert(equivalence3.getAllEquivalentExprs().filter(_.size == 2).contains(Seq(add, add)))
-    assert(
-      equivalence3.getAllEquivalentExprs().filter(_.size == 2).contains(Seq(condition, condition)))
+    assert(equivalence3.getAllExprStates().count(_.useCount == 2) == 2)
+    assert(equivalence3.getAllExprStates().filter(_.useCount == 2).exists(_.expr eq condition))
+    assert(equivalence3.getAllExprStates().filter(_.useCount == 2).exists(_.expr eq add))
 
     // `ifExpr1`, `ifExpr3`
-    assert(equivalence3.getAllEquivalentExprs().count(_.size == 1) == 2)
-    assert(equivalence3.getAllEquivalentExprs().filter(_.size == 1).contains(Seq(ifExpr1)))
-    assert(equivalence3.getAllEquivalentExprs().filter(_.size == 1).contains(Seq(ifExpr3)))
+    assert(equivalence3.getAllExprStates().count(_.useCount == 1) == 2)
+    assert(equivalence3.getAllExprStates().filter(_.useCount == 1).exists(_.expr eq ifExpr1))
+    assert(equivalence3.getAllExprStates().filter(_.useCount == 1).exists(_.expr eq ifExpr3))
   }
 
   test("Children of conditional expressions: CaseWhen") {
@@ -202,8 +198,8 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel
     equivalence1.addExprTree(caseWhenExpr1)
 
     // `add2` is repeatedly in all conditions.
-    assert(equivalence1.getAllEquivalentExprs().count(_.size == 2) == 1)
-    assert(equivalence1.getAllEquivalentExprs().filter(_.size == 2).head == Seq(add2, add2))
+    assert(equivalence1.getAllExprStates().count(_.useCount == 2) == 1)
+    assert(equivalence1.getAllExprStates().filter(_.useCount == 2).head.expr eq add2)
 
     val conditions2 = (GreaterThan(add1, Literal(3)), add1) ::
       (GreaterThan(add2, Literal(4)), add1) ::
@@ -214,8 +210,8 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel
     equivalence2.addExprTree(caseWhenExpr2)
 
     // `add1` is repeatedly in all branch values, and first predicate.
-    assert(equivalence2.getAllEquivalentExprs().count(_.size == 2) == 1)
-    assert(equivalence2.getAllEquivalentExprs().filter(_.size == 2).head == Seq(add1, add1))
+    assert(equivalence2.getAllExprStates().count(_.useCount == 2) == 1)
+    assert(equivalence2.getAllExprStates().filter(_.useCount == 2).head.expr eq add1)
 
     // Negative case. `add1` or `add2` is not commonly used in all predicates/branch values.
     val conditions3 = (GreaterThan(add1, Literal(3)), add2) ::
@@ -225,7 +221,7 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel
     val caseWhenExpr3 = CaseWhen(conditions3, None)
     val equivalence3 = new EquivalentExpressions
     equivalence3.addExprTree(caseWhenExpr3)
-    assert(equivalence3.getAllEquivalentExprs().count(_.size == 2) == 0)
+    assert(equivalence3.getAllExprStates().count(_.useCount == 2) == 0)
   }
 
   test("Children of conditional expressions: Coalesce") {
@@ -240,8 +236,8 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel
     equivalence1.addExprTree(coalesceExpr1)
 
     // `add2` is repeatedly in all conditions.
-    assert(equivalence1.getAllEquivalentExprs().count(_.size == 2) == 1)
-    assert(equivalence1.getAllEquivalentExprs().filter(_.size == 2).head == Seq(add2, add2))
+    assert(equivalence1.getAllExprStates().count(_.useCount == 2) == 1)
+    assert(equivalence1.getAllExprStates().filter(_.useCount == 2).head.expr eq add2)
 
     // Negative case. `add1` and `add2` both are not used in all branches.
     val conditions2 = GreaterThan(add1, Literal(3)) ::
@@ -252,7 +248,7 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel
     val equivalence2 = new EquivalentExpressions
     equivalence2.addExprTree(coalesceExpr2)
 
-    assert(equivalence2.getAllEquivalentExprs().count(_.size == 2) == 0)
+    assert(equivalence2.getAllExprStates().count(_.useCount == 2) == 0)
   }
 
   test("SPARK-34723: Correct parameter type for subexpression elimination under whole-stage") {
@@ -321,9 +317,10 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel
     val equivalence = new EquivalentExpressions
     equivalence.addExprTree(caseWhenExpr)
 
-    val commonExprs = equivalence.getAllEquivalentExprs(1)
+    val commonExprs = equivalence.getAllExprStates(1)
     assert(commonExprs.size == 1)
-    assert(commonExprs.head === Seq(add3, add3))
+    assert(commonExprs.head.useCount == 2)
+    assert(commonExprs.head.expr eq add3)
   }
 
   test("SPARK-35439: Children subexpr should come first than parent subexpr") {
@@ -332,27 +329,29 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel
     val equivalence1 = new EquivalentExpressions
 
     equivalence1.addExprTree(add)
-    assert(equivalence1.getAllEquivalentExprs().head === Seq(add))
+    assert(equivalence1.getAllExprStates().head.expr eq add)
 
     equivalence1.addExprTree(Add(Literal(3), add))
-    assert(equivalence1.getAllEquivalentExprs() ===
-      Seq(Seq(add, add), Seq(Add(Literal(3), add))))
+    assert(equivalence1.getAllExprStates().map(_.useCount) === Seq(2, 1))
+    assert(equivalence1.getAllExprStates().map(_.expr) === Seq(add, Add(Literal(3), add)))
 
     equivalence1.addExprTree(Add(Literal(3), add))
-    assert(equivalence1.getAllEquivalentExprs() ===
-      Seq(Seq(add, add), Seq(Add(Literal(3), add), Add(Literal(3), add))))
+    assert(equivalence1.getAllExprStates().map(_.useCount) === Seq(2, 2))
+    assert(equivalence1.getAllExprStates().map(_.expr) === Seq(add, Add(Literal(3), add)))
 
     val equivalence2 = new EquivalentExpressions
 
     equivalence2.addExprTree(Add(Literal(3), add))
-    assert(equivalence2.getAllEquivalentExprs() === Seq(Seq(add), Seq(Add(Literal(3), add))))
+    assert(equivalence2.getAllExprStates().map(_.useCount) === Seq(1, 1))
+    assert(equivalence2.getAllExprStates().map(_.expr) === Seq(add, Add(Literal(3), add)))
 
     equivalence2.addExprTree(add)
-    assert(equivalence2.getAllEquivalentExprs() === Seq(Seq(add, add), Seq(Add(Literal(3), add))))
+    assert(equivalence2.getAllExprStates().map(_.useCount) === Seq(2, 1))
+    assert(equivalence2.getAllExprStates().map(_.expr) === Seq(add, Add(Literal(3), add)))
 
     equivalence2.addExprTree(Add(Literal(3), add))
-    assert(equivalence2.getAllEquivalentExprs() ===
-      Seq(Seq(add, add), Seq(Add(Literal(3), add), Add(Literal(3), add))))
+    assert(equivalence2.getAllExprStates().map(_.useCount) === Seq(2, 2))
+    assert(equivalence2.getAllExprStates().map(_.expr) === Seq(add, Add(Literal(3), add)))
   }
 
   test("SPARK-35499: Subexpressions should only be extracted from CaseWhen values with an "
@@ -368,28 +367,7 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel
     equivalence.addExprTree(caseWhenExpr)
 
     // `add1` is not in the elseValue, so we can't extract it from the branches
-    assert(equivalence.getAllEquivalentExprs().count(_.size == 2) == 0)
-  }
-
-  test("SPARK-35439: sort exprs with ExpressionContainmentOrdering") {
-    val exprOrdering = new ExpressionContainmentOrdering
-
-    val add1 = Add(Literal(1), Literal(2))
-    val add2 = Add(Literal(2), Literal(3))
-
-    // Non parent-child expressions. Don't sort on them.
-    val exprs = Seq(add2, add1, add2, add1, add2, add1)
-    assert(exprs.sorted(exprOrdering) === exprs)
-
-    val conditions = (GreaterThan(add1, Literal(3)), add1) ::
-      (GreaterThan(add2, Literal(4)), add1) ::
-      (GreaterThan(add2, Literal(5)), add1) :: Nil
-
-    // `caseWhenExpr` contains add1, add2.
-    val caseWhenExpr = CaseWhen(conditions, None)
-    val exprs2 = Seq(caseWhenExpr, add2, add1, add2, add1, add2, add1)
-    assert(exprs2.sorted(exprOrdering) ===
-      Seq(add2, add1, add2, add1, add2, add1, caseWhenExpr))
+    assert(equivalence.getAllExprStates().count(_.useCount == 2) == 0)
   }
 
   test("SPARK-35829: SubExprEliminationState keeps children sub exprs") {
@@ -400,8 +378,8 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel
     val ctx = new CodegenContext()
     val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(exprs)
 
-    val add2State = subExprs.states(add2)
-    val add1State = subExprs.states(add1)
+    val add2State = subExprs.states(ExpressionEquals(add2))
+    val add1State = subExprs.states(ExpressionEquals(add1))
     assert(add2State.children.contains(add1State))
 
     subExprs.states.values.foreach { state =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
index c97c213..da310b6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
@@ -257,7 +257,7 @@ case class HashAggregateExec(
       aggNames: Seq[String],
       aggBufferUpdatingExprs: Seq[Seq[Expression]],
       aggCodeBlocks: Seq[Block],
-      subExprs: Map[Expression, SubExprEliminationState]): Option[Seq[String]] = {
+      subExprs: Map[ExpressionEquals, SubExprEliminationState]): Option[Seq[String]] = {
     val exprValsInSubExprs = subExprs.flatMap { case (_, s) =>
       s.eval.value :: s.eval.isNull :: Nil
     }

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org