You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by do...@apache.org on 2020/12/24 22:44:48 UTC

[spark] branch master updated: [SPARK-30027][SQL] Support codegen for aggregate filters in HashAggregateExec

This is an automated email from the ASF dual-hosted git repository.

dongjoon pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 65a9ac2  [SPARK-30027][SQL] Support codegen for aggregate filters in HashAggregateExec
65a9ac2 is described below

commit 65a9ac2ff4d902976bf3ef89d1d3e29c1e6d5414
Author: Takeshi Yamamuro <ya...@apache.org>
AuthorDate: Thu Dec 24 14:44:16 2020 -0800

    [SPARK-30027][SQL] Support codegen for aggregate filters in HashAggregateExec
    
    ### What changes were proposed in this pull request?
    
    This PR adds code generation support for `HashAggregateExec` when aggregate expressions have filters (`FILTER (WHERE ...)` clauses).
    
    Quick benchmark results:
    ```
    $ ./bin/spark-shell --master=local[1] --conf spark.driver.memory=8g --conf spark.sql.shuffle.partitions=1 -v
    
    scala> spark.range(100000000).selectExpr("id % 3 as k1", "id % 5 as k2", "rand() as v1", "rand() as v2").write.saveAsTable("t")
    scala> sql("SELECT k1, k2, AVG(v1) FILTER (WHERE v2 > 0.5) FROM t GROUP BY k1, k2").write.format("noop").mode("overwrite").save()
    
    >> Before this PR
    Elapsed time: 16.170697619s
    
    >> After this PR
    Elapsed time: 6.7825313s
    ```
    
    The query above is compiled into the code below:
    
    ```
    ...
    /* 285 */   private void agg_doAggregate_avg_0(boolean agg_exprIsNull_2_0, org.apache.spark.sql.catalyst.InternalRow agg_unsafeRowAggBuffer_0, double agg_expr_2_0) throws java.io.IOException {
    /* 286 */     // evaluate aggregate function for avg
    /* 287 */     boolean agg_isNull_10 = true;
    /* 288 */     double agg_value_12 = -1.0;
    /* 289 */     boolean agg_isNull_11 = agg_unsafeRowAggBuffer_0.isNullAt(0);
    /* 290 */     double agg_value_13 = agg_isNull_11 ?
    /* 291 */     -1.0 : (agg_unsafeRowAggBuffer_0.getDouble(0));
    /* 292 */     if (!agg_isNull_11) {
    /* 293 */       agg_agg_isNull_12_0 = true;
    /* 294 */       double agg_value_14 = -1.0;
    /* 295 */       do {
    /* 296 */         if (!agg_exprIsNull_2_0) {
    /* 297 */           agg_agg_isNull_12_0 = false;
    /* 298 */           agg_value_14 = agg_expr_2_0;
    /* 299 */           continue;
    /* 300 */         }
    /* 301 */
    /* 302 */         if (!false) {
    /* 303 */           agg_agg_isNull_12_0 = false;
    /* 304 */           agg_value_14 = 0.0D;
    /* 305 */           continue;
    /* 306 */         }
    /* 307 */
    /* 308 */       } while (false);
    /* 309 */
    /* 310 */       agg_isNull_10 = false; // resultCode could change nullability.
    /* 311 */
    /* 312 */       agg_value_12 = agg_value_13 + agg_value_14;
    /* 313 */
    /* 314 */     }
    /* 315 */     boolean agg_isNull_15 = false;
    /* 316 */     long agg_value_17 = -1L;
    /* 317 */     if (!false && agg_exprIsNull_2_0) {
    /* 318 */       boolean agg_isNull_18 = agg_unsafeRowAggBuffer_0.isNullAt(1);
    /* 319 */       long agg_value_20 = agg_isNull_18 ?
    /* 320 */       -1L : (agg_unsafeRowAggBuffer_0.getLong(1));
    /* 321 */       agg_isNull_15 = agg_isNull_18;
    /* 322 */       agg_value_17 = agg_value_20;
    /* 323 */     } else {
    /* 324 */       boolean agg_isNull_19 = true;
    /* 325 */       long agg_value_21 = -1L;
    /* 326 */       boolean agg_isNull_20 = agg_unsafeRowAggBuffer_0.isNullAt(1);
    /* 327 */       long agg_value_22 = agg_isNull_20 ?
    /* 328 */       -1L : (agg_unsafeRowAggBuffer_0.getLong(1));
    /* 329 */       if (!agg_isNull_20) {
    /* 330 */         agg_isNull_19 = false; // resultCode could change nullability.
    /* 331 */
    /* 332 */         agg_value_21 = agg_value_22 + 1L;
    /* 333 */
    /* 334 */       }
    /* 335 */       agg_isNull_15 = agg_isNull_19;
    /* 336 */       agg_value_17 = agg_value_21;
    /* 337 */     }
    /* 338 */     // update unsafe row buffer
    /* 339 */     if (!agg_isNull_10) {
    /* 340 */       agg_unsafeRowAggBuffer_0.setDouble(0, agg_value_12);
    /* 341 */     } else {
    /* 342 */       agg_unsafeRowAggBuffer_0.setNullAt(0);
    /* 343 */     }
    /* 344 */
    /* 345 */     if (!agg_isNull_15) {
    /* 346 */       agg_unsafeRowAggBuffer_0.setLong(1, agg_value_17);
    /* 347 */     } else {
    /* 348 */       agg_unsafeRowAggBuffer_0.setNullAt(1);
    /* 349 */     }
    /* 350 */   }
    ...
    ```
    
    ### Why are the changes needed?
    
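    To improve performance: aggregations with `FILTER` clauses previously fell back to non-codegen execution, which is markedly slower (see the benchmark above).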
    
    ### Does this PR introduce any user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    Existing tests.
    
    Closes #27019 from maropu/AggregateFilterCodegen.
    
    Authored-by: Takeshi Yamamuro <ya...@apache.org>
    Signed-off-by: Dongjoon Hyun <dh...@apple.com>
---
 .../sql/catalyst/expressions/predicates.scala      |  18 +++
 .../execution/aggregate/HashAggregateExec.scala    | 100 ++++++++--------
 .../sql/execution/basicPhysicalOperators.scala     | 130 ++++++++++++---------
 .../resources/sql-tests/inputs/group-by-filter.sql |   5 +-
 .../resources/sql-tests/results/explain.sql.out    |   4 +-
 5 files changed, 151 insertions(+), 106 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index 250d3fe..c61d247 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -242,6 +242,24 @@ trait PredicateHelper extends AliasHelper with Logging {
         None
       }
   }
+
+  // If one expression and its children are null intolerant, it is null intolerant.
+  protected def isNullIntolerant(expr: Expression): Boolean = expr match {
+    case e: NullIntolerant => e.children.forall(isNullIntolerant)
+    case _ => false
+  }
+
+  protected def outputWithNullability(
+      output: Seq[Attribute],
+      nonNullAttrExprIds: Seq[ExprId]): Seq[Attribute] = {
+    output.map { a =>
+      if (a.nullable && nonNullAttrExprIds.contains(a.exprId)) {
+        a.withNullability(false)
+      } else {
+        a
+      }
+    }
+  }
 }
 
 @ExpressionDescription(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
index 52d0450..cdad9de 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
@@ -53,7 +53,8 @@ case class HashAggregateExec(
     resultExpressions: Seq[NamedExpression],
     child: SparkPlan)
   extends BaseAggregateExec
-  with BlockingOperatorWithCodegen {
+  with BlockingOperatorWithCodegen
+  with GeneratePredicateHelper {
 
   require(HashAggregateExec.supportsAggregate(aggregateBufferAttributes))
 
@@ -131,10 +132,8 @@ case class HashAggregateExec(
   override def usedInputs: AttributeSet = inputSet
 
   override def supportCodegen: Boolean = {
-    // ImperativeAggregate and filter predicate are not supported right now
-    // TODO: SPARK-30027 Support codegen for filter exprs in HashAggregateExec
-    !(aggregateExpressions.exists(_.aggregateFunction.isInstanceOf[ImperativeAggregate]) ||
-        aggregateExpressions.exists(_.filter.isDefined))
+    // ImperativeAggregate are not supported right now
+    !aggregateExpressions.exists(_.aggregateFunction.isInstanceOf[ImperativeAggregate])
   }
 
   override def inputRDDs(): Seq[RDD[InternalRow]] = {
@@ -254,7 +253,7 @@ case class HashAggregateExec(
       aggNames: Seq[String],
       aggBufferUpdatingExprs: Seq[Seq[Expression]],
       aggCodeBlocks: Seq[Block],
-      subExprs: Map[Expression, SubExprEliminationState]): Option[String] = {
+      subExprs: Map[Expression, SubExprEliminationState]): Option[Seq[String]] = {
     val exprValsInSubExprs = subExprs.flatMap { case (_, s) => s.value :: s.isNull :: Nil }
     if (exprValsInSubExprs.exists(_.isInstanceOf[SimpleExprValue])) {
       // `SimpleExprValue`s cannot be used as an input variable for split functions, so
@@ -293,7 +292,7 @@ case class HashAggregateExec(
           val inputVariables = args.map(_.variableName).mkString(", ")
           s"$doAggFuncName($inputVariables);"
         }
-        Some(splitCodes.mkString("\n").trim)
+        Some(splitCodes)
       } else {
         val errMsg = "Failed to split aggregate code into small functions because the parameter " +
           "length of at least one split function went over the JVM limit: " +
@@ -308,6 +307,39 @@ case class HashAggregateExec(
     }
   }
 
+  private def generateEvalCodeForAggFuncs(
+      ctx: CodegenContext,
+      input: Seq[ExprCode],
+      inputAttrs: Seq[Attribute],
+      boundUpdateExprs: Seq[Seq[Expression]],
+      aggNames: Seq[String],
+      aggCodeBlocks: Seq[Block],
+      subExprs: SubExprCodes): String = {
+    val aggCodes = if (conf.codegenSplitAggregateFunc &&
+      aggCodeBlocks.map(_.length).sum > conf.methodSplitThreshold) {
+      val maybeSplitCodes = splitAggregateExpressions(
+        ctx, aggNames, boundUpdateExprs, aggCodeBlocks, subExprs.states)
+
+      maybeSplitCodes.getOrElse(aggCodeBlocks.map(_.code))
+    } else {
+      aggCodeBlocks.map(_.code)
+    }
+
+    aggCodes.zip(aggregateExpressions.map(ae => (ae.mode, ae.filter))).map {
+      case (aggCode, (Partial | Complete, Some(condition))) =>
+        // Note: wrap in "do { } while(false);", so the generated checks can jump out
+        // with "continue;"
+        s"""
+           |do {
+           |  ${generatePredicateCode(ctx, condition, inputAttrs, input)}
+           |  $aggCode
+           |} while(false);
+         """.stripMargin
+      case (aggCode, _) =>
+        aggCode
+    }.mkString("\n")
+  }
+
   private def doConsumeWithoutKeys(ctx: CodegenContext, input: Seq[ExprCode]): String = {
     // only have DeclarativeAggregate
     val functions = aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate])
@@ -354,24 +386,14 @@ case class HashAggregateExec(
        """.stripMargin
     }
 
-    val codeToEvalAggFunc = if (conf.codegenSplitAggregateFunc &&
-        aggCodeBlocks.map(_.length).sum > conf.methodSplitThreshold) {
-      val maybeSplitCode = splitAggregateExpressions(
-        ctx, aggNames, boundUpdateExprs, aggCodeBlocks, subExprs.states)
-
-      maybeSplitCode.getOrElse {
-        aggCodeBlocks.fold(EmptyBlock)(_ + _).code
-      }
-    } else {
-      aggCodeBlocks.fold(EmptyBlock)(_ + _).code
-    }
-
+    val codeToEvalAggFuncs = generateEvalCodeForAggFuncs(
+      ctx, input, inputAttrs, boundUpdateExprs, aggNames, aggCodeBlocks, subExprs)
     s"""
        |// do aggregate
        |// common sub-expressions
        |$effectiveCodes
        |// evaluate aggregate functions and update aggregation buffers
-       |$codeToEvalAggFunc
+       |$codeToEvalAggFuncs
      """.stripMargin
   }
 
@@ -908,7 +930,7 @@ case class HashAggregateExec(
       }
     }
 
-    val inputAttr = aggregateBufferAttributes ++ inputAttributes
+    val inputAttrs = aggregateBufferAttributes ++ inputAttributes
     // Here we set `currentVars(0)` to `currentVars(numBufferSlots)` to null, so that when
     // generating code for buffer columns, we use `INPUT_ROW`(will be the buffer row), while
     // generating input columns, we use `currentVars`.
@@ -930,7 +952,7 @@ case class HashAggregateExec(
     val updateRowInRegularHashMap: String = {
       ctx.INPUT_ROW = unsafeRowBuffer
       val boundUpdateExprs = updateExprs.map { updateExprsForOneFunc =>
-        bindReferences(updateExprsForOneFunc, inputAttr)
+        bindReferences(updateExprsForOneFunc, inputAttrs)
       }
       val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExprs.flatten)
       val effectiveCodes = subExprs.codes.mkString("\n")
@@ -961,23 +983,13 @@ case class HashAggregateExec(
          """.stripMargin
       }
 
-      val codeToEvalAggFunc = if (conf.codegenSplitAggregateFunc &&
-          aggCodeBlocks.map(_.length).sum > conf.methodSplitThreshold) {
-        val maybeSplitCode = splitAggregateExpressions(
-          ctx, aggNames, boundUpdateExprs, aggCodeBlocks, subExprs.states)
-
-        maybeSplitCode.getOrElse {
-          aggCodeBlocks.fold(EmptyBlock)(_ + _).code
-        }
-      } else {
-        aggCodeBlocks.fold(EmptyBlock)(_ + _).code
-      }
-
+      val codeToEvalAggFuncs = generateEvalCodeForAggFuncs(
+        ctx, input, inputAttrs, boundUpdateExprs, aggNames, aggCodeBlocks, subExprs)
       s"""
          |// common sub-expressions
          |$effectiveCodes
          |// evaluate aggregate functions and update aggregation buffers
-         |$codeToEvalAggFunc
+         |$codeToEvalAggFuncs
        """.stripMargin
     }
 
@@ -986,7 +998,7 @@ case class HashAggregateExec(
         if (isVectorizedHashMapEnabled) {
           ctx.INPUT_ROW = fastRowBuffer
           val boundUpdateExprs = updateExprs.map { updateExprsForOneFunc =>
-            bindReferences(updateExprsForOneFunc, inputAttr)
+            bindReferences(updateExprsForOneFunc, inputAttrs)
           }
           val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExprs.flatten)
           val effectiveCodes = subExprs.codes.mkString("\n")
@@ -1016,18 +1028,8 @@ case class HashAggregateExec(
              """.stripMargin
           }
 
-
-          val codeToEvalAggFunc = if (conf.codegenSplitAggregateFunc &&
-              aggCodeBlocks.map(_.length).sum > conf.methodSplitThreshold) {
-            val maybeSplitCode = splitAggregateExpressions(
-              ctx, aggNames, boundUpdateExprs, aggCodeBlocks, subExprs.states)
-
-            maybeSplitCode.getOrElse {
-              aggCodeBlocks.fold(EmptyBlock)(_ + _).code
-            }
-          } else {
-            aggCodeBlocks.fold(EmptyBlock)(_ + _).code
-          }
+          val codeToEvalAggFuncs = generateEvalCodeForAggFuncs(
+            ctx, input, inputAttrs, boundUpdateExprs, aggNames, aggCodeBlocks, subExprs)
 
           // If vectorized fast hash map is on, we first generate code to update row
           // in vectorized fast hash map, if the previous loop up hit vectorized fast hash map.
@@ -1037,7 +1039,7 @@ case class HashAggregateExec(
              |  // common sub-expressions
              |  $effectiveCodes
              |  // evaluate aggregate functions and update aggregation buffers
-             |  $codeToEvalAggFunc
+             |  $codeToEvalAggFuncs
              |} else {
              |  $updateRowInRegularHashMap
              |}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
index d74d0bf..abd3360 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
@@ -109,59 +109,39 @@ case class ProjectExec(projectList: Seq[NamedExpression], child: SparkPlan)
   }
 }
 
-/** Physical plan for Filter. */
-case class FilterExec(condition: Expression, child: SparkPlan)
-  extends UnaryExecNode with CodegenSupport with PredicateHelper {
-
-  // Split out all the IsNotNulls from condition.
-  private val (notNullPreds, otherPreds) = splitConjunctivePredicates(condition).partition {
-    case IsNotNull(a) => isNullIntolerant(a) && a.references.subsetOf(child.outputSet)
-    case _ => false
-  }
-
-  // If one expression and its children are null intolerant, it is null intolerant.
-  private def isNullIntolerant(expr: Expression): Boolean = expr match {
-    case e: NullIntolerant => e.children.forall(isNullIntolerant)
-    case _ => false
-  }
-
-  // The columns that will filtered out by `IsNotNull` could be considered as not nullable.
-  private val notNullAttributes = notNullPreds.flatMap(_.references).distinct.map(_.exprId)
-
-  // Mark this as empty. We'll evaluate the input during doConsume(). We don't want to evaluate
-  // all the variables at the beginning to take advantage of short circuiting.
-  override def usedInputs: AttributeSet = AttributeSet.empty
-
-  override def output: Seq[Attribute] = {
-    child.output.map { a =>
-      if (a.nullable && notNullAttributes.contains(a.exprId)) {
-        a.withNullability(false)
-      } else {
-        a
-      }
+trait GeneratePredicateHelper extends PredicateHelper {
+  self: CodegenSupport =>
+
+  protected def generatePredicateCode(
+      ctx: CodegenContext,
+      condition: Expression,
+      inputAttrs: Seq[Attribute],
+      inputExprCode: Seq[ExprCode]): String = {
+    val (notNullPreds, otherPreds) = splitConjunctivePredicates(condition).partition {
+      case IsNotNull(a) => isNullIntolerant(a) && a.references.subsetOf(AttributeSet(inputAttrs))
+      case _ => false
     }
-  }
-
-  override lazy val metrics = Map(
-    "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
-
-  override def inputRDDs(): Seq[RDD[InternalRow]] = {
-    child.asInstanceOf[CodegenSupport].inputRDDs()
-  }
-
-  protected override def doProduce(ctx: CodegenContext): String = {
-    child.asInstanceOf[CodegenSupport].produce(ctx, this)
-  }
-
-  override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
-    val numOutput = metricTerm(ctx, "numOutputRows")
-
+    val nonNullAttrExprIds = notNullPreds.flatMap(_.references).distinct.map(_.exprId)
+    val outputAttrs = outputWithNullability(inputAttrs, nonNullAttrExprIds)
+    generatePredicateCode(
+      ctx, inputAttrs, inputExprCode, outputAttrs, notNullPreds, otherPreds,
+      nonNullAttrExprIds)
+  }
+
+  protected def generatePredicateCode(
+      ctx: CodegenContext,
+      inputAttrs: Seq[Attribute],
+      inputExprCode: Seq[ExprCode],
+      outputAttrs: Seq[Attribute],
+      notNullPreds: Seq[Expression],
+      otherPreds: Seq[Expression],
+      nonNullAttrExprIds: Seq[ExprId]): String = {
     /**
      * Generates code for `c`, using `in` for input attributes and `attrs` for nullability.
      */
     def genPredicate(c: Expression, in: Seq[ExprCode], attrs: Seq[Attribute]): String = {
       val bound = BindReferences.bindReference(c, attrs)
-      val evaluated = evaluateRequiredVariables(child.output, in, c.references)
+      val evaluated = evaluateRequiredVariables(inputAttrs, in, c.references)
 
       // Generate the code for the predicate.
       val ev = ExpressionCanonicalizer.execute(bound).genCode(ctx)
@@ -195,10 +175,10 @@ case class FilterExec(condition: Expression, child: SparkPlan)
         if (idx != -1 && !generatedIsNotNullChecks(idx)) {
           generatedIsNotNullChecks(idx) = true
           // Use the child's output. The nullability is what the child produced.
-          genPredicate(notNullPreds(idx), input, child.output)
-        } else if (notNullAttributes.contains(r.exprId) && !extraIsNotNullAttrs.contains(r)) {
+          genPredicate(notNullPreds(idx), inputExprCode, inputAttrs)
+        } else if (nonNullAttrExprIds.contains(r.exprId) && !extraIsNotNullAttrs.contains(r)) {
           extraIsNotNullAttrs += r
-          genPredicate(IsNotNull(r), input, child.output)
+          genPredicate(IsNotNull(r), inputExprCode, inputAttrs)
         } else {
           ""
         }
@@ -208,18 +188,61 @@ case class FilterExec(condition: Expression, child: SparkPlan)
       // enforced them with the IsNotNull checks above.
       s"""
          |$nullChecks
-         |${genPredicate(c, input, output)}
+         |${genPredicate(c, inputExprCode, outputAttrs)}
        """.stripMargin.trim
     }.mkString("\n")
 
     val nullChecks = notNullPreds.zipWithIndex.map { case (c, idx) =>
       if (!generatedIsNotNullChecks(idx)) {
-        genPredicate(c, input, child.output)
+        genPredicate(c, inputExprCode, inputAttrs)
       } else {
         ""
       }
     }.mkString("\n")
 
+    s"""
+       |$generated
+       |$nullChecks
+     """.stripMargin
+  }
+}
+
+/** Physical plan for Filter. */
+case class FilterExec(condition: Expression, child: SparkPlan)
+  extends UnaryExecNode with CodegenSupport with GeneratePredicateHelper {
+
+  // Split out all the IsNotNulls from condition.
+  private val (notNullPreds, otherPreds) = splitConjunctivePredicates(condition).partition {
+    case IsNotNull(a) => isNullIntolerant(a) && a.references.subsetOf(child.outputSet)
+    case _ => false
+  }
+
+  // The columns that will filtered out by `IsNotNull` could be considered as not nullable.
+  private val notNullAttributes = notNullPreds.flatMap(_.references).distinct.map(_.exprId)
+
+  // Mark this as empty. We'll evaluate the input during doConsume(). We don't want to evaluate
+  // all the variables at the beginning to take advantage of short circuiting.
+  override def usedInputs: AttributeSet = AttributeSet.empty
+
+  override def output: Seq[Attribute] = outputWithNullability(child.output, notNullAttributes)
+
+  override lazy val metrics = Map(
+    "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
+
+  override def inputRDDs(): Seq[RDD[InternalRow]] = {
+    child.asInstanceOf[CodegenSupport].inputRDDs()
+  }
+
+  protected override def doProduce(ctx: CodegenContext): String = {
+    child.asInstanceOf[CodegenSupport].produce(ctx, this)
+  }
+
+  override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
+    val numOutput = metricTerm(ctx, "numOutputRows")
+
+    val predicateCode = generatePredicateCode(
+      ctx, child.output, input, output, notNullPreds, otherPreds, notNullAttributes)
+
     // Reset the isNull to false for the not-null columns, then the followed operators could
     // generate better code (remove dead branches).
     val resultVars = input.zipWithIndex.map { case (ev, i) =>
@@ -232,8 +255,7 @@ case class FilterExec(condition: Expression, child: SparkPlan)
     // Note: wrap in "do { } while(false);", so the generated checks can jump out with "continue;"
     s"""
        |do {
-       |  $generated
-       |  $nullChecks
+       |  $predicateCode
        |  $numOutput.add(1);
        |  ${consume(ctx, resultVars)}
        |} while(false);
diff --git a/sql/core/src/test/resources/sql-tests/inputs/group-by-filter.sql b/sql/core/src/test/resources/sql-tests/inputs/group-by-filter.sql
index e4193d8..c1ccb65 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/group-by-filter.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/group-by-filter.sql
@@ -1,4 +1,7 @@
--- Test filter clause for aggregate expression.
+-- Test filter clause for aggregate expression with codegen on and off.
+--CONFIG_DIM1 spark.sql.codegen.wholeStage=true
+--CONFIG_DIM1 spark.sql.codegen.wholeStage=false,spark.sql.codegen.factoryMode=CODEGEN_ONLY
+--CONFIG_DIM1 spark.sql.codegen.wholeStage=false,spark.sql.codegen.factoryMode=NO_CODEGEN
 
 --CONFIG_DIM1 spark.sql.optimizeNullAwareAntiJoin=true
 --CONFIG_DIM1 spark.sql.optimizeNullAwareAntiJoin=false
diff --git a/sql/core/src/test/resources/sql-tests/results/explain.sql.out b/sql/core/src/test/resources/sql-tests/results/explain.sql.out
index 886b98e..a4c9238 100644
--- a/sql/core/src/test/resources/sql-tests/results/explain.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/explain.sql.out
@@ -878,7 +878,7 @@ struct<plan:string>
 == Physical Plan ==
 * HashAggregate (5)
 +- Exchange (4)
-   +- HashAggregate (3)
+   +- * HashAggregate (3)
       +- * ColumnarToRow (2)
          +- Scan parquet default.explain_temp1 (1)
 
@@ -892,7 +892,7 @@ ReadSchema: struct<key:int,val:int>
 (2) ColumnarToRow [codegen id : 1]
 Input [2]: [key#x, val#x]
 
-(3) HashAggregate
+(3) HashAggregate [codegen id : 1]
 Input [2]: [key#x, val#x]
 Keys: []
 Functions [3]: [partial_count(val#x), partial_sum(cast(key#x as bigint)), partial_count(key#x) FILTER (WHERE (val#x > 1))]


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org