Posted to commits@spark.apache.org by do...@apache.org on 2020/12/24 22:44:48 UTC
[spark] branch master updated: [SPARK-30027][SQL] Support codegen for aggregate filters in HashAggregateExec
This is an automated email from the ASF dual-hosted git repository.
dongjoon pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 65a9ac2 [SPARK-30027][SQL] Support codegen for aggregate filters in HashAggregateExec
65a9ac2 is described below
commit 65a9ac2ff4d902976bf3ef89d1d3e29c1e6d5414
Author: Takeshi Yamamuro <ya...@apache.org>
AuthorDate: Thu Dec 24 14:44:16 2020 -0800
[SPARK-30027][SQL] Support codegen for aggregate filters in HashAggregateExec
### What changes were proposed in this pull request?
This PR adds code generation support to `HashAggregateExec` for aggregate expressions with FILTER clauses.
Quick benchmark results:
```
$ ./bin/spark-shell --master=local[1] --conf spark.driver.memory=8g --conf spark.sql.shuffle.partitions=1 -v
scala> spark.range(100000000).selectExpr("id % 3 as k1", "id % 5 as k2", "rand() as v1", "rand() as v2").write.saveAsTable("t")
scala> sql("SELECT k1, k2, AVG(v1) FILTER (WHERE v2 > 0.5) FROM t GROUP BY k1, k2").write.format("noop").mode("overwrite").save()
>> Before this PR
Elapsed time: 16.170697619s
>> After this PR
Elapsed time: 6.7825313s
```
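To double-check which plan fragments actually get compiled, the generated code can be dumped with the `debugCodegen()` helper (a spark-shell sketch; it assumes the implicits from `org.apache.spark.sql.execution.debug` are in scope):
```
scala> import org.apache.spark.sql.execution.debug._
scala> sql("SELECT k1, k2, AVG(v1) FILTER (WHERE v2 > 0.5) FROM t GROUP BY k1, k2").debugCodegen()
```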
The query above is compiled into the code below:
```
...
/* 285 */ private void agg_doAggregate_avg_0(boolean agg_exprIsNull_2_0, org.apache.spark.sql.catalyst.InternalRow agg_unsafeRowAggBuffer_0, double agg_expr_2_0) throws java.io.IOException {
/* 286 */ // evaluate aggregate function for avg
/* 287 */ boolean agg_isNull_10 = true;
/* 288 */ double agg_value_12 = -1.0;
/* 289 */ boolean agg_isNull_11 = agg_unsafeRowAggBuffer_0.isNullAt(0);
/* 290 */ double agg_value_13 = agg_isNull_11 ?
/* 291 */ -1.0 : (agg_unsafeRowAggBuffer_0.getDouble(0));
/* 292 */ if (!agg_isNull_11) {
/* 293 */ agg_agg_isNull_12_0 = true;
/* 294 */ double agg_value_14 = -1.0;
/* 295 */ do {
/* 296 */ if (!agg_exprIsNull_2_0) {
/* 297 */ agg_agg_isNull_12_0 = false;
/* 298 */ agg_value_14 = agg_expr_2_0;
/* 299 */ continue;
/* 300 */ }
/* 301 */
/* 302 */ if (!false) {
/* 303 */ agg_agg_isNull_12_0 = false;
/* 304 */ agg_value_14 = 0.0D;
/* 305 */ continue;
/* 306 */ }
/* 307 */
/* 308 */ } while (false);
/* 309 */
/* 310 */ agg_isNull_10 = false; // resultCode could change nullability.
/* 311 */
/* 312 */ agg_value_12 = agg_value_13 + agg_value_14;
/* 313 */
/* 314 */ }
/* 315 */ boolean agg_isNull_15 = false;
/* 316 */ long agg_value_17 = -1L;
/* 317 */ if (!false && agg_exprIsNull_2_0) {
/* 318 */ boolean agg_isNull_18 = agg_unsafeRowAggBuffer_0.isNullAt(1);
/* 319 */ long agg_value_20 = agg_isNull_18 ?
/* 320 */ -1L : (agg_unsafeRowAggBuffer_0.getLong(1));
/* 321 */ agg_isNull_15 = agg_isNull_18;
/* 322 */ agg_value_17 = agg_value_20;
/* 323 */ } else {
/* 324 */ boolean agg_isNull_19 = true;
/* 325 */ long agg_value_21 = -1L;
/* 326 */ boolean agg_isNull_20 = agg_unsafeRowAggBuffer_0.isNullAt(1);
/* 327 */ long agg_value_22 = agg_isNull_20 ?
/* 328 */ -1L : (agg_unsafeRowAggBuffer_0.getLong(1));
/* 329 */ if (!agg_isNull_20) {
/* 330 */ agg_isNull_19 = false; // resultCode could change nullability.
/* 331 */
/* 332 */ agg_value_21 = agg_value_22 + 1L;
/* 333 */
/* 334 */ }
/* 335 */ agg_isNull_15 = agg_isNull_19;
/* 336 */ agg_value_17 = agg_value_21;
/* 337 */ }
/* 338 */ // update unsafe row buffer
/* 339 */ if (!agg_isNull_10) {
/* 340 */ agg_unsafeRowAggBuffer_0.setDouble(0, agg_value_12);
/* 341 */ } else {
/* 342 */ agg_unsafeRowAggBuffer_0.setNullAt(0);
/* 343 */ }
/* 344 */
/* 345 */ if (!agg_isNull_15) {
/* 346 */ agg_unsafeRowAggBuffer_0.setLong(1, agg_value_17);
/* 347 */ } else {
/* 348 */ agg_unsafeRowAggBuffer_0.setNullAt(1);
/* 349 */ }
/* 350 */ }
...
```
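The `do { ... } while (false);` blocks above exist so the generated predicate checks can jump past the aggregate update with `continue;` (see the note in `generateEvalCodeForAggFuncs` below). As a hand-written sketch of what that control flow computes for `AVG(v1) FILTER (WHERE v2 > 0.5)` (not actual codegen output):
```
// A row failing the FILTER predicate skips the buffer update entirely.
final case class AvgBuffer(var sum: Double = 0.0, var count: Long = 0L)

def updateAvg(buf: AvgBuffer, v1: Double, v2: Double): Unit = {
  if (!(v2 > 0.5)) return // the generated code does this via "continue;"
  buf.sum += v1           // slot 0 of the aggregation buffer
  buf.count += 1L         // slot 1 of the aggregation buffer
}
```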
### Why are the changes needed?
Aggregate queries with FILTER clauses previously disabled whole-stage codegen in `HashAggregateExec` (see the TODO removed below); compiling them brings a sizable speedup (~2.4x in the benchmark above).
### Does this PR introduce any user-facing change?
No.
### How was this patch tested?
Existing tests; `group-by-filter.sql` is additionally run with whole-stage codegen on and off via the new CONFIG_DIM1 entries.
Closes #27019 from maropu/AggregateFilterCodegen.
Authored-by: Takeshi Yamamuro <ya...@apache.org>
Signed-off-by: Dongjoon Hyun <dh...@apple.com>
---
.../sql/catalyst/expressions/predicates.scala | 18 +++
.../execution/aggregate/HashAggregateExec.scala | 100 ++++++++--------
.../sql/execution/basicPhysicalOperators.scala | 130 ++++++++++++---------
.../resources/sql-tests/inputs/group-by-filter.sql | 5 +-
.../resources/sql-tests/results/explain.sql.out | 4 +-
5 files changed, 151 insertions(+), 106 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index 250d3fe..c61d247 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -242,6 +242,24 @@ trait PredicateHelper extends AliasHelper with Logging {
None
}
}
+
+ // If one expression and its children are null intolerant, it is null intolerant.
+ protected def isNullIntolerant(expr: Expression): Boolean = expr match {
+ case e: NullIntolerant => e.children.forall(isNullIntolerant)
+ case _ => false
+ }
+
+ protected def outputWithNullability(
+ output: Seq[Attribute],
+ nonNullAttrExprIds: Seq[ExprId]): Seq[Attribute] = {
+ output.map { a =>
+ if (a.nullable && nonNullAttrExprIds.contains(a.exprId)) {
+ a.withNullability(false)
+ } else {
+ a
+ }
+ }
+ }
}
@ExpressionDescription(
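The two helpers above were lifted out of `FilterExec` (see the `basicPhysicalOperators.scala` hunk below) so that `HashAggregateExec` can reuse them. Roughly, `outputWithNullability` tightens nullability attribute by attribute; a spark-shell sketch of the underlying call (attribute name made up):
```
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.types.IntegerType

val a = AttributeReference("a", IntegerType, nullable = true)()
// An attribute guarded by IsNotNull on a null-intolerant expression can be
// reported as non-nullable downstream:
val tightened = a.withNullability(false)
```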
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
index 52d0450..cdad9de 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
@@ -53,7 +53,8 @@ case class HashAggregateExec(
resultExpressions: Seq[NamedExpression],
child: SparkPlan)
extends BaseAggregateExec
- with BlockingOperatorWithCodegen {
+ with BlockingOperatorWithCodegen
+ with GeneratePredicateHelper {
require(HashAggregateExec.supportsAggregate(aggregateBufferAttributes))
@@ -131,10 +132,8 @@ case class HashAggregateExec(
override def usedInputs: AttributeSet = inputSet
override def supportCodegen: Boolean = {
- // ImperativeAggregate and filter predicate are not supported right now
- // TODO: SPARK-30027 Support codegen for filter exprs in HashAggregateExec
- !(aggregateExpressions.exists(_.aggregateFunction.isInstanceOf[ImperativeAggregate]) ||
- aggregateExpressions.exists(_.filter.isDefined))
+ // ImperativeAggregate are not supported right now
+ !aggregateExpressions.exists(_.aggregateFunction.isInstanceOf[ImperativeAggregate])
}
override def inputRDDs(): Seq[RDD[InternalRow]] = {
@@ -254,7 +253,7 @@ case class HashAggregateExec(
aggNames: Seq[String],
aggBufferUpdatingExprs: Seq[Seq[Expression]],
aggCodeBlocks: Seq[Block],
- subExprs: Map[Expression, SubExprEliminationState]): Option[String] = {
+ subExprs: Map[Expression, SubExprEliminationState]): Option[Seq[String]] = {
val exprValsInSubExprs = subExprs.flatMap { case (_, s) => s.value :: s.isNull :: Nil }
if (exprValsInSubExprs.exists(_.isInstanceOf[SimpleExprValue])) {
// `SimpleExprValue`s cannot be used as an input variable for split functions, so
@@ -293,7 +292,7 @@ case class HashAggregateExec(
val inputVariables = args.map(_.variableName).mkString(", ")
s"$doAggFuncName($inputVariables);"
}
- Some(splitCodes.mkString("\n").trim)
+ Some(splitCodes)
} else {
val errMsg = "Failed to split aggregate code into small functions because the parameter " +
"length of at least one split function went over the JVM limit: " +
@@ -308,6 +307,39 @@ case class HashAggregateExec(
}
}
+ private def generateEvalCodeForAggFuncs(
+ ctx: CodegenContext,
+ input: Seq[ExprCode],
+ inputAttrs: Seq[Attribute],
+ boundUpdateExprs: Seq[Seq[Expression]],
+ aggNames: Seq[String],
+ aggCodeBlocks: Seq[Block],
+ subExprs: SubExprCodes): String = {
+ val aggCodes = if (conf.codegenSplitAggregateFunc &&
+ aggCodeBlocks.map(_.length).sum > conf.methodSplitThreshold) {
+ val maybeSplitCodes = splitAggregateExpressions(
+ ctx, aggNames, boundUpdateExprs, aggCodeBlocks, subExprs.states)
+
+ maybeSplitCodes.getOrElse(aggCodeBlocks.map(_.code))
+ } else {
+ aggCodeBlocks.map(_.code)
+ }
+
+ aggCodes.zip(aggregateExpressions.map(ae => (ae.mode, ae.filter))).map {
+ case (aggCode, (Partial | Complete, Some(condition))) =>
+ // Note: wrap in "do { } while(false);", so the generated checks can jump out
+ // with "continue;"
+ s"""
+ |do {
+ | ${generatePredicateCode(ctx, condition, inputAttrs, input)}
+ | $aggCode
+ |} while(false);
+ """.stripMargin
+ case (aggCode, _) =>
+ aggCode
+ }.mkString("\n")
+ }
+
private def doConsumeWithoutKeys(ctx: CodegenContext, input: Seq[ExprCode]): String = {
// only have DeclarativeAggregate
val functions = aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate])
@@ -354,24 +386,14 @@ case class HashAggregateExec(
""".stripMargin
}
- val codeToEvalAggFunc = if (conf.codegenSplitAggregateFunc &&
- aggCodeBlocks.map(_.length).sum > conf.methodSplitThreshold) {
- val maybeSplitCode = splitAggregateExpressions(
- ctx, aggNames, boundUpdateExprs, aggCodeBlocks, subExprs.states)
-
- maybeSplitCode.getOrElse {
- aggCodeBlocks.fold(EmptyBlock)(_ + _).code
- }
- } else {
- aggCodeBlocks.fold(EmptyBlock)(_ + _).code
- }
-
+ val codeToEvalAggFuncs = generateEvalCodeForAggFuncs(
+ ctx, input, inputAttrs, boundUpdateExprs, aggNames, aggCodeBlocks, subExprs)
s"""
|// do aggregate
|// common sub-expressions
|$effectiveCodes
|// evaluate aggregate functions and update aggregation buffers
- |$codeToEvalAggFunc
+ |$codeToEvalAggFuncs
""".stripMargin
}
@@ -908,7 +930,7 @@ case class HashAggregateExec(
}
}
- val inputAttr = aggregateBufferAttributes ++ inputAttributes
+ val inputAttrs = aggregateBufferAttributes ++ inputAttributes
// Here we set `currentVars(0)` to `currentVars(numBufferSlots)` to null, so that when
// generating code for buffer columns, we use `INPUT_ROW`(will be the buffer row), while
// generating input columns, we use `currentVars`.
@@ -930,7 +952,7 @@ case class HashAggregateExec(
val updateRowInRegularHashMap: String = {
ctx.INPUT_ROW = unsafeRowBuffer
val boundUpdateExprs = updateExprs.map { updateExprsForOneFunc =>
- bindReferences(updateExprsForOneFunc, inputAttr)
+ bindReferences(updateExprsForOneFunc, inputAttrs)
}
val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExprs.flatten)
val effectiveCodes = subExprs.codes.mkString("\n")
@@ -961,23 +983,13 @@ case class HashAggregateExec(
""".stripMargin
}
- val codeToEvalAggFunc = if (conf.codegenSplitAggregateFunc &&
- aggCodeBlocks.map(_.length).sum > conf.methodSplitThreshold) {
- val maybeSplitCode = splitAggregateExpressions(
- ctx, aggNames, boundUpdateExprs, aggCodeBlocks, subExprs.states)
-
- maybeSplitCode.getOrElse {
- aggCodeBlocks.fold(EmptyBlock)(_ + _).code
- }
- } else {
- aggCodeBlocks.fold(EmptyBlock)(_ + _).code
- }
-
+ val codeToEvalAggFuncs = generateEvalCodeForAggFuncs(
+ ctx, input, inputAttrs, boundUpdateExprs, aggNames, aggCodeBlocks, subExprs)
s"""
|// common sub-expressions
|$effectiveCodes
|// evaluate aggregate functions and update aggregation buffers
- |$codeToEvalAggFunc
+ |$codeToEvalAggFuncs
""".stripMargin
}
@@ -986,7 +998,7 @@ case class HashAggregateExec(
if (isVectorizedHashMapEnabled) {
ctx.INPUT_ROW = fastRowBuffer
val boundUpdateExprs = updateExprs.map { updateExprsForOneFunc =>
- bindReferences(updateExprsForOneFunc, inputAttr)
+ bindReferences(updateExprsForOneFunc, inputAttrs)
}
val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExprs.flatten)
val effectiveCodes = subExprs.codes.mkString("\n")
@@ -1016,18 +1028,8 @@ case class HashAggregateExec(
""".stripMargin
}
-
- val codeToEvalAggFunc = if (conf.codegenSplitAggregateFunc &&
- aggCodeBlocks.map(_.length).sum > conf.methodSplitThreshold) {
- val maybeSplitCode = splitAggregateExpressions(
- ctx, aggNames, boundUpdateExprs, aggCodeBlocks, subExprs.states)
-
- maybeSplitCode.getOrElse {
- aggCodeBlocks.fold(EmptyBlock)(_ + _).code
- }
- } else {
- aggCodeBlocks.fold(EmptyBlock)(_ + _).code
- }
+ val codeToEvalAggFuncs = generateEvalCodeForAggFuncs(
+ ctx, input, inputAttrs, boundUpdateExprs, aggNames, aggCodeBlocks, subExprs)
// If vectorized fast hash map is on, we first generate code to update row
// in vectorized fast hash map, if the previous loop up hit vectorized fast hash map.
@@ -1037,7 +1039,7 @@ case class HashAggregateExec(
| // common sub-expressions
| $effectiveCodes
| // evaluate aggregate functions and update aggregation buffers
- | $codeToEvalAggFunc
+ | $codeToEvalAggFuncs
|} else {
| $updateRowInRegularHashMap
|}
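For intuition about what the generated guard computes: for NULL-ignoring aggregates such as AVG, a FILTER clause is semantically equivalent to feeding NULL for the non-matching rows, so the two queries below should agree on the benchmark table `t` (a sanity-check sketch, not part of this patch):
```
sql("SELECT k1, k2, AVG(v1) FILTER (WHERE v2 > 0.5) FROM t GROUP BY k1, k2")
sql("SELECT k1, k2, AVG(CASE WHEN v2 > 0.5 THEN v1 END) FROM t GROUP BY k1, k2")
```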
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
index d74d0bf..abd3360 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
@@ -109,59 +109,39 @@ case class ProjectExec(projectList: Seq[NamedExpression], child: SparkPlan)
}
}
-/** Physical plan for Filter. */
-case class FilterExec(condition: Expression, child: SparkPlan)
- extends UnaryExecNode with CodegenSupport with PredicateHelper {
-
- // Split out all the IsNotNulls from condition.
- private val (notNullPreds, otherPreds) = splitConjunctivePredicates(condition).partition {
- case IsNotNull(a) => isNullIntolerant(a) && a.references.subsetOf(child.outputSet)
- case _ => false
- }
-
- // If one expression and its children are null intolerant, it is null intolerant.
- private def isNullIntolerant(expr: Expression): Boolean = expr match {
- case e: NullIntolerant => e.children.forall(isNullIntolerant)
- case _ => false
- }
-
- // The columns that will filtered out by `IsNotNull` could be considered as not nullable.
- private val notNullAttributes = notNullPreds.flatMap(_.references).distinct.map(_.exprId)
-
- // Mark this as empty. We'll evaluate the input during doConsume(). We don't want to evaluate
- // all the variables at the beginning to take advantage of short circuiting.
- override def usedInputs: AttributeSet = AttributeSet.empty
-
- override def output: Seq[Attribute] = {
- child.output.map { a =>
- if (a.nullable && notNullAttributes.contains(a.exprId)) {
- a.withNullability(false)
- } else {
- a
- }
+trait GeneratePredicateHelper extends PredicateHelper {
+ self: CodegenSupport =>
+
+ protected def generatePredicateCode(
+ ctx: CodegenContext,
+ condition: Expression,
+ inputAttrs: Seq[Attribute],
+ inputExprCode: Seq[ExprCode]): String = {
+ val (notNullPreds, otherPreds) = splitConjunctivePredicates(condition).partition {
+ case IsNotNull(a) => isNullIntolerant(a) && a.references.subsetOf(AttributeSet(inputAttrs))
+ case _ => false
}
- }
-
- override lazy val metrics = Map(
- "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
-
- override def inputRDDs(): Seq[RDD[InternalRow]] = {
- child.asInstanceOf[CodegenSupport].inputRDDs()
- }
-
- protected override def doProduce(ctx: CodegenContext): String = {
- child.asInstanceOf[CodegenSupport].produce(ctx, this)
- }
-
- override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
- val numOutput = metricTerm(ctx, "numOutputRows")
-
+ val nonNullAttrExprIds = notNullPreds.flatMap(_.references).distinct.map(_.exprId)
+ val outputAttrs = outputWithNullability(inputAttrs, nonNullAttrExprIds)
+ generatePredicateCode(
+ ctx, inputAttrs, inputExprCode, outputAttrs, notNullPreds, otherPreds,
+ nonNullAttrExprIds)
+ }
+
+ protected def generatePredicateCode(
+ ctx: CodegenContext,
+ inputAttrs: Seq[Attribute],
+ inputExprCode: Seq[ExprCode],
+ outputAttrs: Seq[Attribute],
+ notNullPreds: Seq[Expression],
+ otherPreds: Seq[Expression],
+ nonNullAttrExprIds: Seq[ExprId]): String = {
/**
* Generates code for `c`, using `in` for input attributes and `attrs` for nullability.
*/
def genPredicate(c: Expression, in: Seq[ExprCode], attrs: Seq[Attribute]): String = {
val bound = BindReferences.bindReference(c, attrs)
- val evaluated = evaluateRequiredVariables(child.output, in, c.references)
+ val evaluated = evaluateRequiredVariables(inputAttrs, in, c.references)
// Generate the code for the predicate.
val ev = ExpressionCanonicalizer.execute(bound).genCode(ctx)
@@ -195,10 +175,10 @@ case class FilterExec(condition: Expression, child: SparkPlan)
if (idx != -1 && !generatedIsNotNullChecks(idx)) {
generatedIsNotNullChecks(idx) = true
// Use the child's output. The nullability is what the child produced.
- genPredicate(notNullPreds(idx), input, child.output)
- } else if (notNullAttributes.contains(r.exprId) && !extraIsNotNullAttrs.contains(r)) {
+ genPredicate(notNullPreds(idx), inputExprCode, inputAttrs)
+ } else if (nonNullAttrExprIds.contains(r.exprId) && !extraIsNotNullAttrs.contains(r)) {
extraIsNotNullAttrs += r
- genPredicate(IsNotNull(r), input, child.output)
+ genPredicate(IsNotNull(r), inputExprCode, inputAttrs)
} else {
""
}
@@ -208,18 +188,61 @@ case class FilterExec(condition: Expression, child: SparkPlan)
// enforced them with the IsNotNull checks above.
s"""
|$nullChecks
- |${genPredicate(c, input, output)}
+ |${genPredicate(c, inputExprCode, outputAttrs)}
""".stripMargin.trim
}.mkString("\n")
val nullChecks = notNullPreds.zipWithIndex.map { case (c, idx) =>
if (!generatedIsNotNullChecks(idx)) {
- genPredicate(c, input, child.output)
+ genPredicate(c, inputExprCode, inputAttrs)
} else {
""
}
}.mkString("\n")
+ s"""
+ |$generated
+ |$nullChecks
+ """.stripMargin
+ }
+}
+
+/** Physical plan for Filter. */
+case class FilterExec(condition: Expression, child: SparkPlan)
+ extends UnaryExecNode with CodegenSupport with GeneratePredicateHelper {
+
+ // Split out all the IsNotNulls from condition.
+ private val (notNullPreds, otherPreds) = splitConjunctivePredicates(condition).partition {
+ case IsNotNull(a) => isNullIntolerant(a) && a.references.subsetOf(child.outputSet)
+ case _ => false
+ }
+
+ // The columns that will filtered out by `IsNotNull` could be considered as not nullable.
+ private val notNullAttributes = notNullPreds.flatMap(_.references).distinct.map(_.exprId)
+
+ // Mark this as empty. We'll evaluate the input during doConsume(). We don't want to evaluate
+ // all the variables at the beginning to take advantage of short circuiting.
+ override def usedInputs: AttributeSet = AttributeSet.empty
+
+ override def output: Seq[Attribute] = outputWithNullability(child.output, notNullAttributes)
+
+ override lazy val metrics = Map(
+ "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
+
+ override def inputRDDs(): Seq[RDD[InternalRow]] = {
+ child.asInstanceOf[CodegenSupport].inputRDDs()
+ }
+
+ protected override def doProduce(ctx: CodegenContext): String = {
+ child.asInstanceOf[CodegenSupport].produce(ctx, this)
+ }
+
+ override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
+ val numOutput = metricTerm(ctx, "numOutputRows")
+
+ val predicateCode = generatePredicateCode(
+ ctx, child.output, input, output, notNullPreds, otherPreds, notNullAttributes)
+
// Reset the isNull to false for the not-null columns, then the followed operators could
// generate better code (remove dead branches).
val resultVars = input.zipWithIndex.map { case (ev, i) =>
@@ -232,8 +255,7 @@ case class FilterExec(condition: Expression, child: SparkPlan)
// Note: wrap in "do { } while(false);", so the generated checks can jump out with "continue;"
s"""
|do {
- | $generated
- | $nullChecks
+ | $predicateCode
| $numOutput.add(1);
| ${consume(ctx, resultVars)}
|} while(false);
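Net effect of this hunk: the predicate codegen that previously lived only in `FilterExec.doConsume` now sits in `GeneratePredicateHelper`, a trait with a `CodegenSupport` self-type, so `HashAggregateExec` can call `generatePredicateCode` for aggregate filters too. The extraction pattern in miniature (a standalone sketch, not Spark code):
```
// A self-type restricts mix-in to codegen-capable operators while letting
// several operators share one code-generation routine.
trait CodegenCapable { def operatorName: String }

trait PredicateCodegenSketch { self: CodegenCapable =>
  def guard(cond: String, body: String): String =
    s"""|do { // $operatorName
        |  if (!($cond)) continue;
        |  $body
        |} while(false);""".stripMargin
}
```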
diff --git a/sql/core/src/test/resources/sql-tests/inputs/group-by-filter.sql b/sql/core/src/test/resources/sql-tests/inputs/group-by-filter.sql
index e4193d8..c1ccb65 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/group-by-filter.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/group-by-filter.sql
@@ -1,4 +1,7 @@
--- Test filter clause for aggregate expression.
+-- Test filter clause for aggregate expression with codegen on and off.
+--CONFIG_DIM1 spark.sql.codegen.wholeStage=true
+--CONFIG_DIM1 spark.sql.codegen.wholeStage=false,spark.sql.codegen.factoryMode=CODEGEN_ONLY
+--CONFIG_DIM1 spark.sql.codegen.wholeStage=false,spark.sql.codegen.factoryMode=NO_CODEGEN
--CONFIG_DIM1 spark.sql.optimizeNullAwareAntiJoin=true
--CONFIG_DIM1 spark.sql.optimizeNullAwareAntiJoin=false
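The new CONFIG_DIM1 entries make the SQL query test harness re-run this input file under each listed configuration: whole-stage codegen enabled, codegen-only expression evaluation, and interpreter only. Locally the file can be re-run with something like `build/sbt "sql/testOnly *SQLQueryTestSuite -- -z group-by-filter.sql"` (the usual invocation for these golden-file tests; the exact incantation may differ across Spark versions).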
diff --git a/sql/core/src/test/resources/sql-tests/results/explain.sql.out b/sql/core/src/test/resources/sql-tests/results/explain.sql.out
index 886b98e..a4c9238 100644
--- a/sql/core/src/test/resources/sql-tests/results/explain.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/explain.sql.out
@@ -878,7 +878,7 @@ struct<plan:string>
== Physical Plan ==
* HashAggregate (5)
+- Exchange (4)
- +- HashAggregate (3)
+ +- * HashAggregate (3)
+- * ColumnarToRow (2)
+- Scan parquet default.explain_temp1 (1)
@@ -892,7 +892,7 @@ ReadSchema: struct<key:int,val:int>
(2) ColumnarToRow [codegen id : 1]
Input [2]: [key#x, val#x]
-(3) HashAggregate
+(3) HashAggregate [codegen id : 1]
Input [2]: [key#x, val#x]
Keys: []
Functions [3]: [partial_count(val#x), partial_sum(cast(key#x as bigint)), partial_count(key#x) FILTER (WHERE (val#x > 1))]
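Note the golden-file change above: the `*` prefix marks operators that participate in whole-stage codegen, so the partial `HashAggregate` (whose `count(key#x)` carries a FILTER clause) is now compiled into codegen stage 1 instead of falling back to the interpreted path.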