Posted to commits@spark.apache.org by vi...@apache.org on 2021/10/11 05:20:59 UTC
[spark] branch master updated: [SPARK-36645][SQL] Aggregate (Min/Max/Count) push down for Parquet
This is an automated email from the ASF dual-hosted git repository.
viirya pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 128168d [SPARK-36645][SQL] Aggregate (Min/Max/Count) push down for Parquet
128168d is described below
commit 128168d8c4019a1e10a9f1be734868524f6a09f0
Author: Huaxin Gao <hu...@apple.com>
AuthorDate: Sun Oct 10 22:20:09 2021 -0700
[SPARK-36645][SQL] Aggregate (Min/Max/Count) push down for Parquet
### What changes were proposed in this pull request?
Push down Min/Max/Count to Parquet with the following restrictions:
- Nested types such as Array, Map, or Struct are not pushed down.
- Timestamp is not pushed down because the INT96 sort order is undefined and Parquet doesn't return statistics for INT96.
- If the aggregate is on a partition column, only Count is pushed down; Min and Max are not, because Parquet footers don't contain min/max statistics for partition columns.
- If the file doesn't have statistics for the aggregate columns, Spark throws an exception.
- Currently, if a filter or GROUP BY is involved, Min/Max/Count are not pushed down; this restriction will be lifted when the filter or GROUP BY is on a partition column (https://issues.apache.org/jira/browse/SPARK-36646 and https://issues.apache.org/jira/browse/SPARK-36647)
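To make the restrictions concrete, here is a small illustrative sketch (the table `t` and its columns, including the partition column `p`, are hypothetical):

```scala
// Assumes spark.sql.parquet.aggregatePushdown is set to true (see below).

// Eligible: no filter, no GROUP BY, aggregates on primitive non-timestamp columns.
spark.sql("SELECT MIN(col1), MAX(col1), COUNT(*), COUNT(col1) FROM t")

// Not eligible: a filter is involved, so the aggregates are computed by Spark as usual.
spark.sql("SELECT MIN(col1) FROM t WHERE col2 = 8")

// Not eligible for MIN/MAX: p is assumed to be a partition column,
// and only COUNT is pushed down for partition columns.
spark.sql("SELECT MAX(p) FROM t")
```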
### Why are the changes needed?
Since Parquet footers contain statistics for min, max, and row count, we want to take advantage of this information and push Min/Max/Count down to the Parquet layer for better performance.
### Does this PR introduce _any_ user-facing change?
Yes, `SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED` (`spark.sql.parquet.aggregatePushdown`) was added. If set to true, Min/Max/Count are pushed down to Parquet.
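A minimal usage sketch, assuming a hypothetical Parquet table `t`:

```scala
// New config added by this patch; the default is false.
spark.conf.set("spark.sql.parquet.aggregatePushdown", "true")

// When the push down succeeds, the scan's metadata in the plan reports it,
// e.g. "PushedAggregation: [MIN(col1), COUNT(*)]".
spark.sql("SELECT MIN(col1), COUNT(*) FROM t").explain()
```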
### How was this patch tested?
New test suites.
Closes #33639 from huaxingao/parquet_agg.
Authored-by: Huaxin Gao <hu...@apple.com>
Signed-off-by: Liang-Chi Hsieh <vi...@gmail.com>
---
.../org/apache/spark/sql/internal/SQLConf.scala | 10 +
.../org/apache/spark/sql/types/StructType.scala | 2 +-
.../datasources/parquet/ParquetUtils.scala | 227 +++++++++
.../execution/datasources/v2/FileScanBuilder.scala | 2 +-
.../v2/parquet/ParquetPartitionReaderFactory.scala | 123 ++++-
.../datasources/v2/parquet/ParquetScan.scala | 37 +-
.../v2/parquet/ParquetScanBuilder.scala | 96 +++-
.../scala/org/apache/spark/sql/FileScanSuite.scala | 2 +-
.../parquet/ParquetAggregatePushDownSuite.scala | 518 +++++++++++++++++++++
9 files changed, 984 insertions(+), 33 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 6443dfd..98aad1c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -853,6 +853,14 @@ object SQLConf {
.checkValue(threshold => threshold >= 0, "The threshold must not be negative.")
.createWithDefault(10)
+ val PARQUET_AGGREGATE_PUSHDOWN_ENABLED = buildConf("spark.sql.parquet.aggregatePushdown")
+ .doc("If true, MAX/MIN/COUNT without filter and group by will be pushed" +
+ " down to Parquet for optimization. MAX/MIN/COUNT for complex types and timestamp" +
+ " can't be pushed down")
+ .version("3.3.0")
+ .booleanConf
+ .createWithDefault(false)
+
val PARQUET_WRITE_LEGACY_FORMAT = buildConf("spark.sql.parquet.writeLegacyFormat")
.doc("If true, data will be written in a way of Spark 1.4 and earlier. For example, decimal " +
"values will be written in Apache Parquet's fixed-length byte array format, which other " +
@@ -3660,6 +3668,8 @@ class SQLConf extends Serializable with Logging {
def parquetFilterPushDownInFilterThreshold: Int =
getConf(PARQUET_FILTER_PUSHDOWN_INFILTERTHRESHOLD)
+ def parquetAggregatePushDown: Boolean = getConf(PARQUET_AGGREGATE_PUSHDOWN_ENABLED)
+
def orcFilterPushDown: Boolean = getConf(ORC_FILTER_PUSHDOWN_ENABLED)
def isOrcSchemaMergingEnabled: Boolean = getConf(ORC_SCHEMA_MERGING_ENABLED)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala
index c9862cb..50b197f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala
@@ -115,7 +115,7 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru
def names: Array[String] = fieldNames
private lazy val fieldNamesSet: Set[String] = fieldNames.toSet
- private lazy val nameToField: Map[String, StructField] = fields.map(f => f.name -> f).toMap
+ private[sql] lazy val nameToField: Map[String, StructField] = fields.map(f => f.name -> f).toMap
private lazy val nameToIndex: Map[String, Int] = fieldNames.zipWithIndex.toMap
override def equals(that: Any): Boolean = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala
index b91d75c..1093f9c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala
@@ -16,11 +16,28 @@
*/
package org.apache.spark.sql.execution.datasources.parquet
+import java.util
+
+import scala.collection.mutable
+import scala.language.existentials
+
import org.apache.hadoop.fs.{FileStatus, Path}
import org.apache.parquet.hadoop.ParquetFileWriter
+import org.apache.parquet.hadoop.metadata.{ColumnChunkMetaData, ParquetMetadata}
+import org.apache.parquet.io.api.Binary
+import org.apache.parquet.schema.{PrimitiveType, Types}
+import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName
+import org.apache.spark.SparkException
import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.connector.expressions.aggregate.{Aggregation, Count, CountStar, Max, Min}
+import org.apache.spark.sql.execution.RowToColumnConverter
+import org.apache.spark.sql.execution.datasources.PartitioningUtils
+import org.apache.spark.sql.execution.vectorized.{OffHeapColumnVector, OnHeapColumnVector}
+import org.apache.spark.sql.internal.SQLConf.{LegacyBehaviorPolicy, PARQUET_AGGREGATE_PUSHDOWN_ENABLED}
import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector}
object ParquetUtils {
def inferSchema(
@@ -127,4 +144,214 @@ object ParquetUtils {
file.getName == ParquetFileWriter.PARQUET_COMMON_METADATA_FILE ||
file.getName == ParquetFileWriter.PARQUET_METADATA_FILE
}
+
+ /**
+ * When the partial aggregates (Max/Min/Count) are pushed down to Parquet, we don't need to
+ * createRowBaseReader to read data from Parquet and aggregate at the Spark layer. Instead we want
+ * to get the partial aggregates (Max/Min/Count) result using the statistics information
+ * from Parquet footer file, and then construct an InternalRow from these aggregate results.
+ *
+ * @return Aggregate results in the format of InternalRow
+ */
+ private[sql] def createAggInternalRowFromFooter(
+ footer: ParquetMetadata,
+ filePath: String,
+ dataSchema: StructType,
+ partitionSchema: StructType,
+ aggregation: Aggregation,
+ aggSchema: StructType,
+ datetimeRebaseMode: LegacyBehaviorPolicy.Value,
+ isCaseSensitive: Boolean): InternalRow = {
+ val (primitiveTypes, values) = getPushedDownAggResult(
+ footer, filePath, dataSchema, partitionSchema, aggregation, isCaseSensitive)
+
+ val builder = Types.buildMessage
+ primitiveTypes.foreach(t => builder.addField(t))
+ val parquetSchema = builder.named("root")
+
+ val schemaConverter = new ParquetToSparkSchemaConverter
+ val converter = new ParquetRowConverter(schemaConverter, parquetSchema, aggSchema,
+ None, datetimeRebaseMode, LegacyBehaviorPolicy.CORRECTED, NoopUpdater)
+ val primitiveTypeNames = primitiveTypes.map(_.getPrimitiveTypeName)
+ primitiveTypeNames.zipWithIndex.foreach {
+ case (PrimitiveType.PrimitiveTypeName.BOOLEAN, i) =>
+ val v = values(i).asInstanceOf[Boolean]
+ converter.getConverter(i).asPrimitiveConverter.addBoolean(v)
+ case (PrimitiveType.PrimitiveTypeName.INT32, i) =>
+ val v = values(i).asInstanceOf[Integer]
+ converter.getConverter(i).asPrimitiveConverter.addInt(v)
+ case (PrimitiveType.PrimitiveTypeName.INT64, i) =>
+ val v = values(i).asInstanceOf[Long]
+ converter.getConverter(i).asPrimitiveConverter.addLong(v)
+ case (PrimitiveType.PrimitiveTypeName.FLOAT, i) =>
+ val v = values(i).asInstanceOf[Float]
+ converter.getConverter(i).asPrimitiveConverter.addFloat(v)
+ case (PrimitiveType.PrimitiveTypeName.DOUBLE, i) =>
+ val v = values(i).asInstanceOf[Double]
+ converter.getConverter(i).asPrimitiveConverter.addDouble(v)
+ case (PrimitiveType.PrimitiveTypeName.BINARY, i) =>
+ val v = values(i).asInstanceOf[Binary]
+ converter.getConverter(i).asPrimitiveConverter.addBinary(v)
+ case (PrimitiveType.PrimitiveTypeName.FIXED_LEN_BYTE_ARRAY, i) =>
+ val v = values(i).asInstanceOf[Binary]
+ converter.getConverter(i).asPrimitiveConverter.addBinary(v)
+ case (_, i) =>
+ throw new SparkException("Unexpected parquet type name: " + primitiveTypeNames(i))
+ }
+ converter.currentRecord
+ }
+
+ /**
+ * When the aggregates (Max/Min/Count) are pushed down to Parquet and
+ * PARQUET_VECTORIZED_READER_ENABLED is set to true, we don't need buildColumnarReader
+ * to read data from Parquet and aggregate at the Spark layer. Instead we want
+ * to get the aggregates (Max/Min/Count) result using the statistics information
+ * from Parquet footer file, and then construct a ColumnarBatch from these aggregate results.
+ *
+ * @return Aggregate results in the format of ColumnarBatch
+ */
+ private[sql] def createAggColumnarBatchFromFooter(
+ footer: ParquetMetadata,
+ filePath: String,
+ dataSchema: StructType,
+ partitionSchema: StructType,
+ aggregation: Aggregation,
+ aggSchema: StructType,
+ offHeap: Boolean,
+ datetimeRebaseMode: LegacyBehaviorPolicy.Value,
+ isCaseSensitive: Boolean): ColumnarBatch = {
+ val row = createAggInternalRowFromFooter(
+ footer,
+ filePath,
+ dataSchema,
+ partitionSchema,
+ aggregation,
+ aggSchema,
+ datetimeRebaseMode,
+ isCaseSensitive)
+ val converter = new RowToColumnConverter(aggSchema)
+ val columnVectors = if (offHeap) {
+ OffHeapColumnVector.allocateColumns(1, aggSchema)
+ } else {
+ OnHeapColumnVector.allocateColumns(1, aggSchema)
+ }
+ converter.convert(row, columnVectors.toArray)
+ new ColumnarBatch(columnVectors.asInstanceOf[Array[ColumnVector]], 1)
+ }
+
+ /**
+ * Calculate the pushed down aggregates (Max/Min/Count) result using the statistics
+ * information from Parquet footer file.
+ *
+ * @return A tuple of `Array[PrimitiveType]` and `Array[Any]`.
+ * The first element is the Parquet PrimitiveType of the aggregate column,
+ * and the second element is the aggregated value.
+ */
+ private[sql] def getPushedDownAggResult(
+ footer: ParquetMetadata,
+ filePath: String,
+ dataSchema: StructType,
+ partitionSchema: StructType,
+ aggregation: Aggregation,
+ isCaseSensitive: Boolean)
+ : (Array[PrimitiveType], Array[Any]) = {
+ val footerFileMetaData = footer.getFileMetaData
+ val fields = footerFileMetaData.getSchema.getFields
+ val blocks = footer.getBlocks
+ val primitiveTypeBuilder = mutable.ArrayBuilder.make[PrimitiveType]
+ val valuesBuilder = mutable.ArrayBuilder.make[Any]
+
+ assert(aggregation.groupByColumns.length == 0, "group by shouldn't be pushed down")
+ aggregation.aggregateExpressions.foreach { agg =>
+ var value: Any = None
+ var rowCount = 0L
+ var isCount = false
+ var index = 0
+ var schemaName = ""
+ blocks.forEach { block =>
+ val blockMetaData = block.getColumns
+ agg match {
+ case max: Max =>
+ val colName = max.column.fieldNames.head
+ index = dataSchema.fieldNames.toList.indexOf(colName)
+ schemaName = "max(" + colName + ")"
+ val currentMax = getCurrentBlockMaxOrMin(filePath, blockMetaData, index, true)
+ if (value == None || currentMax.asInstanceOf[Comparable[Any]].compareTo(value) > 0) {
+ value = currentMax
+ }
+ case min: Min =>
+ val colName = min.column.fieldNames.head
+ index = dataSchema.fieldNames.toList.indexOf(colName)
+ schemaName = "min(" + colName + ")"
+ val currentMin = getCurrentBlockMaxOrMin(filePath, blockMetaData, index, false)
+ if (value == None || currentMin.asInstanceOf[Comparable[Any]].compareTo(value) < 0) {
+ value = currentMin
+ }
+ case count: Count =>
+ schemaName = "count(" + count.column.fieldNames.head + ")"
+ rowCount += block.getRowCount
+ var isPartitionCol = false
+ if (partitionSchema.fields.map(PartitioningUtils.getColName(_, isCaseSensitive))
+ .toSet.contains(count.column.fieldNames.head)) {
+ isPartitionCol = true
+ }
+ isCount = true
+ if (!isPartitionCol) {
+ index = dataSchema.fieldNames.toList.indexOf(count.column.fieldNames.head)
+ // Count(*) includes the null values, but Count(colName) doesn't.
+ rowCount -= getNumNulls(filePath, blockMetaData, index)
+ }
+ case _: CountStar =>
+ schemaName = "count(*)"
+ rowCount += block.getRowCount
+ isCount = true
+ case _ =>
+ }
+ }
+ if (isCount) {
+ valuesBuilder += rowCount
+ primitiveTypeBuilder += Types.required(PrimitiveTypeName.INT64).named(schemaName);
+ } else {
+ valuesBuilder += value
+ val field = fields.get(index)
+ primitiveTypeBuilder += Types.required(field.asPrimitiveType.getPrimitiveTypeName)
+ .as(field.getLogicalTypeAnnotation)
+ .length(field.asPrimitiveType.getTypeLength)
+ .named(schemaName)
+ }
+ }
+ (primitiveTypeBuilder.result, valuesBuilder.result)
+ }
+
+ /**
+ * Get the Max or Min value for ith column in the current block
+ *
+ * @return the Max or Min value
+ */
+ private def getCurrentBlockMaxOrMin(
+ filePath: String,
+ columnChunkMetaData: util.List[ColumnChunkMetaData],
+ i: Int,
+ isMax: Boolean): Any = {
+ val statistics = columnChunkMetaData.get(i).getStatistics
+ if (!statistics.hasNonNullValue) {
+ throw new UnsupportedOperationException(s"No min/max found for Parquet file $filePath. " +
+ s"Set SQLConf ${PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key} to false and execute again")
+ } else {
+ if (isMax) statistics.genericGetMax else statistics.genericGetMin
+ }
+ }
+
+ private def getNumNulls(
+ filePath: String,
+ columnChunkMetaData: util.List[ColumnChunkMetaData],
+ i: Int): Long = {
+ val statistics = columnChunkMetaData.get(i).getStatistics
+ if (!statistics.isNumNullsSet) {
+ throw new UnsupportedOperationException(s"Number of nulls not set for Parquet file" +
+ s" $filePath. Set SQLConf ${PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key} to false and execute" +
+ s" again")
+ }
+ statistics.getNumNulls;
+ }
}
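As an aside for readers of this diff, the footer-based computation in `getPushedDownAggResult` above can be approximated with a small standalone sketch using the same parquet-hadoop metadata APIs; the file path and column index below are hypothetical, and `ParquetFileReader.readFooter` (deprecated in newer parquet-mr releases) is assumed to be available:

```scala
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
import org.apache.parquet.format.converter.ParquetMetadataConverter.NO_FILTER
import org.apache.parquet.hadoop.ParquetFileReader

import scala.collection.JavaConverters._

// Hypothetical file path and column index (0), for illustration only.
val footer = ParquetFileReader.readFooter(
  new Configuration(), new Path("/tmp/data.parquet"), NO_FILTER)

// MIN over column 0, computed purely from row-group (block) statistics,
// as getCurrentBlockMaxOrMin does above.
val minValue = footer.getBlocks.asScala.map { block =>
  val stats = block.getColumns.get(0).getStatistics
  require(stats.hasNonNullValue, "no min/max statistics for this column")
  stats.genericGetMin.asInstanceOf[Comparable[Any]]
}.reduceLeft((a, b) => if (a.compareTo(b) <= 0) a else b)

// COUNT(*) is just the sum of row counts across all row groups.
val rowCount = footer.getBlocks.asScala.map(_.getRowCount).sum
```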
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScanBuilder.scala
index 309f045..2dc4137 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScanBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScanBuilder.scala
@@ -96,6 +96,6 @@ abstract class FileScanBuilder(
private def createRequiredNameSet(): Set[String] =
requiredSchema.fields.map(PartitioningUtils.getColName(_, isCaseSensitive)).toSet
- private val partitionNameSet: Set[String] =
+ val partitionNameSet: Set[String] =
partitionSchema.fields.map(PartitioningUtils.getColName(_, isCaseSensitive)).toSet
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala
index 058669b..111018b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala
@@ -25,14 +25,16 @@ import org.apache.hadoop.mapreduce._
import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl
import org.apache.parquet.filter2.compat.FilterCompat
import org.apache.parquet.filter2.predicate.{FilterApi, FilterPredicate}
-import org.apache.parquet.format.converter.ParquetMetadataConverter.SKIP_ROW_GROUPS
+import org.apache.parquet.format.converter.ParquetMetadataConverter.{NO_FILTER, SKIP_ROW_GROUPS}
import org.apache.parquet.hadoop.{ParquetInputFormat, ParquetRecordReader}
+import org.apache.parquet.hadoop.metadata.{FileMetaData, ParquetMetadata}
import org.apache.spark.TaskContext
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.util.DateTimeUtils
+import org.apache.spark.sql.connector.expressions.aggregate.Aggregation
import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader}
import org.apache.spark.sql.execution.datasources.{DataSourceUtils, PartitionedFile, RecordReaderIterator}
import org.apache.spark.sql.execution.datasources.parquet._
@@ -53,6 +55,7 @@ import org.apache.spark.util.SerializableConfiguration
* @param readDataSchema Required schema of Parquet files.
* @param partitionSchema Schema of partitions.
* @param filters Filters to be pushed down in the batch scan.
+ * @param aggregation Aggregation to be pushed down in the batch scan.
* @param parquetOptions The options of Parquet datasource that are set for the read.
*/
case class ParquetPartitionReaderFactory(
@@ -62,6 +65,7 @@ case class ParquetPartitionReaderFactory(
readDataSchema: StructType,
partitionSchema: StructType,
filters: Array[Filter],
+ aggregation: Option[Aggregation],
parquetOptions: ParquetOptions) extends FilePartitionReaderFactory with Logging {
private val isCaseSensitive = sqlConf.caseSensitiveAnalysis
private val resultSchema = StructType(partitionSchema.fields ++ readDataSchema.fields)
@@ -80,6 +84,30 @@ case class ParquetPartitionReaderFactory(
private val datetimeRebaseModeInRead = parquetOptions.datetimeRebaseModeInRead
private val int96RebaseModeInRead = parquetOptions.int96RebaseModeInRead
+ private def getFooter(file: PartitionedFile): ParquetMetadata = {
+ val conf = broadcastedConf.value.value
+ val filePath = new Path(new URI(file.filePath))
+
+ if (aggregation.isEmpty) {
+ ParquetFooterReader.readFooter(conf, filePath, SKIP_ROW_GROUPS)
+ } else {
+ // For aggregate push down, we will get max/min/count from footer statistics.
+ // We want to read the footer for the whole file instead of reading multiple
+ // footers for every split of the file. If the start offset in PartitionedFile is 0,
+ // we read the footer. Otherwise the footer for this file has already been read for
+ // another split, so we skip reading it again.
+ if (file.start != 0) return null
+ ParquetFooterReader.readFooter(conf, filePath, NO_FILTER)
+ }
+ }
+
+ private def getDatetimeRebaseMode(
+ footerFileMetaData: FileMetaData): LegacyBehaviorPolicy.Value = {
+ DataSourceUtils.datetimeRebaseMode(
+ footerFileMetaData.getKeyValueMetaData.get,
+ datetimeRebaseModeInRead)
+ }
+
override def supportColumnarReads(partition: InputPartition): Boolean = {
sqlConf.parquetVectorizedReaderEnabled && sqlConf.wholeStageEnabled &&
resultSchema.length <= sqlConf.wholeStageMaxNumFields &&
@@ -87,18 +115,44 @@ case class ParquetPartitionReaderFactory(
}
override def buildReader(file: PartitionedFile): PartitionReader[InternalRow] = {
- val reader = if (enableVectorizedReader) {
- createVectorizedReader(file)
- } else {
- createRowBaseReader(file)
- }
+ val fileReader = if (aggregation.isEmpty) {
+ val reader = if (enableVectorizedReader) {
+ createVectorizedReader(file)
+ } else {
+ createRowBaseReader(file)
+ }
+
+ new PartitionReader[InternalRow] {
+ override def next(): Boolean = reader.nextKeyValue()
- val fileReader = new PartitionReader[InternalRow] {
- override def next(): Boolean = reader.nextKeyValue()
+ override def get(): InternalRow = reader.getCurrentValue.asInstanceOf[InternalRow]
- override def get(): InternalRow = reader.getCurrentValue.asInstanceOf[InternalRow]
+ override def close(): Unit = reader.close()
+ }
+ } else {
+ new PartitionReader[InternalRow] {
+ private var hasNext = true
+ private lazy val row: InternalRow = {
+ val footer = getFooter(file)
+ if (footer != null && footer.getBlocks.size > 0) {
+ ParquetUtils.createAggInternalRowFromFooter(footer, file.filePath, dataSchema,
+ partitionSchema, aggregation.get, readDataSchema,
+ getDatetimeRebaseMode(footer.getFileMetaData), isCaseSensitive)
+ } else {
+ null
+ }
+ }
+ override def next(): Boolean = {
+ hasNext && row != null
+ }
- override def close(): Unit = reader.close()
+ override def get(): InternalRow = {
+ hasNext = false
+ row
+ }
+
+ override def close(): Unit = {}
+ }
}
new PartitionReaderWithPartitionValues(fileReader, readDataSchema,
@@ -106,17 +160,45 @@ case class ParquetPartitionReaderFactory(
}
override def buildColumnarReader(file: PartitionedFile): PartitionReader[ColumnarBatch] = {
- val vectorizedReader = createVectorizedReader(file)
- vectorizedReader.enableReturningBatches()
+ val fileReader = if (aggregation.isEmpty) {
+ val vectorizedReader = createVectorizedReader(file)
+ vectorizedReader.enableReturningBatches()
+
+ new PartitionReader[ColumnarBatch] {
+ override def next(): Boolean = vectorizedReader.nextKeyValue()
- new PartitionReader[ColumnarBatch] {
- override def next(): Boolean = vectorizedReader.nextKeyValue()
+ override def get(): ColumnarBatch =
+ vectorizedReader.getCurrentValue.asInstanceOf[ColumnarBatch]
- override def get(): ColumnarBatch =
- vectorizedReader.getCurrentValue.asInstanceOf[ColumnarBatch]
+ override def close(): Unit = vectorizedReader.close()
+ }
+ } else {
+ new PartitionReader[ColumnarBatch] {
+ private var hasNext = true
+ private val row: ColumnarBatch = {
+ val footer = getFooter(file)
+ if (footer != null && footer.getBlocks.size > 0) {
+ ParquetUtils.createAggColumnarBatchFromFooter(footer, file.filePath, dataSchema,
+ partitionSchema, aggregation.get, readDataSchema, enableOffHeapColumnVector,
+ getDatetimeRebaseMode(footer.getFileMetaData), isCaseSensitive)
+ } else {
+ null
+ }
+ }
+
+ override def next(): Boolean = {
+ hasNext && row != null
+ }
+
+ override def get(): ColumnarBatch = {
+ hasNext = false
+ row
+ }
- override def close(): Unit = vectorizedReader.close()
+ override def close(): Unit = {}
+ }
}
+ fileReader
}
private def buildReaderBase[T](
@@ -131,11 +213,8 @@ case class ParquetPartitionReaderFactory(
val filePath = new Path(new URI(file.filePath))
val split = new FileSplit(filePath, file.start, file.length, Array.empty[String])
- lazy val footerFileMetaData =
- ParquetFooterReader.readFooter(conf, filePath, SKIP_ROW_GROUPS).getFileMetaData
- val datetimeRebaseMode = DataSourceUtils.datetimeRebaseMode(
- footerFileMetaData.getKeyValueMetaData.get,
- datetimeRebaseModeInRead)
+ lazy val footerFileMetaData = getFooter(file).getFileMetaData
+ val datetimeRebaseMode = getDatetimeRebaseMode(footerFileMetaData)
// Try to push down filters when filter push-down is enabled.
val pushed = if (enableParquetFilterPushDown) {
val parquetSchema = footerFileMetaData.getSchema
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala
index e277e33..42dc287 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala
@@ -24,6 +24,7 @@ import org.apache.parquet.hadoop.ParquetInputFormat
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.connector.expressions.aggregate.Aggregation
import org.apache.spark.sql.connector.read.PartitionReaderFactory
import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex
import org.apache.spark.sql.execution.datasources.parquet.{ParquetOptions, ParquetReadSupport, ParquetWriteSupport}
@@ -43,10 +44,17 @@ case class ParquetScan(
readPartitionSchema: StructType,
pushedFilters: Array[Filter],
options: CaseInsensitiveStringMap,
+ pushedAggregate: Option[Aggregation] = None,
partitionFilters: Seq[Expression] = Seq.empty,
dataFilters: Seq[Expression] = Seq.empty) extends FileScan {
override def isSplitable(path: Path): Boolean = true
+ override def readSchema(): StructType = {
+ // If aggregate is pushed down, schema has already been pruned in `ParquetScanBuilder`
+ // and no need to call super.readSchema()
+ if (pushedAggregate.nonEmpty) readDataSchema else super.readSchema()
+ }
+
override def createReaderFactory(): PartitionReaderFactory = {
val readDataSchemaAsJson = readDataSchema.json
hadoopConf.set(ParquetInputFormat.READ_SUPPORT_CLASS, classOf[ParquetReadSupport].getName)
@@ -86,23 +94,46 @@ case class ParquetScan(
readDataSchema,
readPartitionSchema,
pushedFilters,
+ pushedAggregate,
new ParquetOptions(options.asCaseSensitiveMap.asScala.toMap, sqlConf))
}
override def equals(obj: Any): Boolean = obj match {
case p: ParquetScan =>
+ val pushedDownAggEqual = if (pushedAggregate.nonEmpty && p.pushedAggregate.nonEmpty) {
+ equivalentAggregations(pushedAggregate.get, p.pushedAggregate.get)
+ } else {
+ pushedAggregate.isEmpty && p.pushedAggregate.isEmpty
+ }
super.equals(p) && dataSchema == p.dataSchema && options == p.options &&
- equivalentFilters(pushedFilters, p.pushedFilters)
+ equivalentFilters(pushedFilters, p.pushedFilters) && pushedDownAggEqual
case _ => false
}
override def hashCode(): Int = getClass.hashCode()
+ lazy private val (pushedAggregationsStr, pushedGroupByStr) = if (pushedAggregate.nonEmpty) {
+ (seqToString(pushedAggregate.get.aggregateExpressions),
+ seqToString(pushedAggregate.get.groupByColumns))
+ } else {
+ ("[]", "[]")
+ }
+
override def description(): String = {
- super.description() + ", PushedFilters: " + seqToString(pushedFilters)
+ super.description() + ", PushedFilters: " + seqToString(pushedFilters) +
+ ", PushedAggregation: " + pushedAggregationsStr +
+ ", PushedGroupBy: " + pushedGroupByStr
}
override def getMetaData(): Map[String, String] = {
- super.getMetaData() ++ Map("PushedFilters" -> seqToString(pushedFilters))
+ super.getMetaData() ++ Map("PushedFilters" -> seqToString(pushedFilters)) ++
+ Map("PushedAggregation" -> pushedAggregationsStr) ++
+ Map("PushedGroupBy" -> pushedGroupByStr)
+ }
+
+ private def equivalentAggregations(a: Aggregation, b: Aggregation): Boolean = {
+ a.aggregateExpressions.sortBy(_.hashCode())
+ .sameElements(b.aggregateExpressions.sortBy(_.hashCode())) &&
+ a.groupByColumns.sortBy(_.hashCode()).sameElements(b.groupByColumns.sortBy(_.hashCode()))
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala
index 9a0e4b4..c579867 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala
@@ -20,13 +20,15 @@ package org.apache.spark.sql.execution.datasources.v2.parquet
import scala.collection.JavaConverters._
import org.apache.spark.sql.SparkSession
-import org.apache.spark.sql.connector.read.Scan
+import org.apache.spark.sql.connector.expressions.NamedReference
+import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Aggregation, Count, CountStar, Max, Min}
+import org.apache.spark.sql.connector.read.{Scan, SupportsPushDownAggregates}
import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex
import org.apache.spark.sql.execution.datasources.parquet.{ParquetFilters, SparkToParquetSchemaConverter}
import org.apache.spark.sql.execution.datasources.v2.FileScanBuilder
import org.apache.spark.sql.internal.SQLConf.LegacyBehaviorPolicy
import org.apache.spark.sql.sources.Filter
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.types.{ArrayType, LongType, MapType, StructField, StructType, TimestampType}
import org.apache.spark.sql.util.CaseInsensitiveStringMap
case class ParquetScanBuilder(
@@ -35,7 +37,8 @@ case class ParquetScanBuilder(
schema: StructType,
dataSchema: StructType,
options: CaseInsensitiveStringMap)
- extends FileScanBuilder(sparkSession, fileIndex, dataSchema) {
+ extends FileScanBuilder(sparkSession, fileIndex, dataSchema)
+ with SupportsPushDownAggregates {
lazy val hadoopConf = {
val caseSensitiveMap = options.asCaseSensitiveMap.asScala.toMap
// Hadoop Configurations are case sensitive.
@@ -70,6 +73,10 @@ case class ParquetScanBuilder(
}
}
+ private var finalSchema = new StructType()
+
+ private var pushedAggregations = Option.empty[Aggregation]
+
override protected val supportsNestedSchemaPruning: Boolean = true
override def pushDataFilters(dataFilters: Array[Filter]): Array[Filter] = dataFilters
@@ -79,8 +86,87 @@ case class ParquetScanBuilder(
// All filters that can be converted to Parquet are pushed down.
override def pushedFilters(): Array[Filter] = pushedParquetFilters
+ override def pushAggregation(aggregation: Aggregation): Boolean = {
+
+ def getStructFieldForCol(col: NamedReference): StructField = {
+ schema.nameToField(col.fieldNames.head)
+ }
+
+ def isPartitionCol(col: NamedReference) = {
+ partitionNameSet.contains(col.fieldNames.head)
+ }
+
+ def processMinOrMax(agg: AggregateFunc): Boolean = {
+ val (column, aggType) = agg match {
+ case max: Max => (max.column, "max")
+ case min: Min => (min.column, "min")
+ case _ =>
+ throw new IllegalArgumentException(s"Unexpected type of AggregateFunc ${agg.describe}")
+ }
+
+ if (isPartitionCol(column)) {
+ // don't push down partition column, footer doesn't have max/min for partition column
+ return false
+ }
+ val structField = getStructFieldForCol(column)
+
+ structField.dataType match {
+ // not push down complex type
+ // not push down Timestamp because INT96 sort order is undefined,
+ // Parquet doesn't return statistics for INT96
+ case StructType(_) | ArrayType(_, _) | MapType(_, _, _) | TimestampType =>
+ false
+ case _ =>
+ finalSchema = finalSchema.add(structField.copy(s"$aggType(" + structField.name + ")"))
+ true
+ }
+ }
+
+ if (!sparkSession.sessionState.conf.parquetAggregatePushDown ||
+ aggregation.groupByColumns.nonEmpty || dataFilters.length > 0) {
+ // Parquet footer has max/min/count for columns
+ // e.g. SELECT COUNT(col1) FROM t
+ // but footer doesn't have max/min/count for a column if max/min/count
+ // are combined with filter or group by
+ // e.g. SELECT COUNT(col1) FROM t WHERE col2 = 8
+ // SELECT COUNT(col1) FROM t GROUP BY col2
+ // Todo: 1. add support if groupby column is partition col
+ // (https://issues.apache.org/jira/browse/SPARK-36646)
+ // 2. add support if filter col is partition col
+ // (https://issues.apache.org/jira/browse/SPARK-36647)
+ return false
+ }
+
+ aggregation.groupByColumns.foreach { col =>
+ if (col.fieldNames.length != 1) return false
+ finalSchema = finalSchema.add(getStructFieldForCol(col))
+ }
+
+ aggregation.aggregateExpressions.foreach {
+ case max: Max =>
+ if (!processMinOrMax(max)) return false
+ case min: Min =>
+ if (!processMinOrMax(min)) return false
+ case count: Count =>
+ if (count.column.fieldNames.length != 1 || count.isDistinct) return false
+ finalSchema =
+ finalSchema.add(StructField(s"count(" + count.column.fieldNames.head + ")", LongType))
+ case _: CountStar =>
+ finalSchema = finalSchema.add(StructField("count(*)", LongType))
+ case _ =>
+ return false
+ }
+ this.pushedAggregations = Some(aggregation)
+ true
+ }
+
override def build(): Scan = {
- ParquetScan(sparkSession, hadoopConf, fileIndex, dataSchema, readDataSchema(),
- readPartitionSchema(), pushedParquetFilters, options, partitionFilters, dataFilters)
+ // the `finalSchema` is either pruned in pushAggregation (if aggregates are
+ // pushed down), or pruned in readDataSchema() (in regular column pruning). These
+ // two are mutually exclusive.
+ if (pushedAggregations.isEmpty) finalSchema = readDataSchema()
+ ParquetScan(sparkSession, hadoopConf, fileIndex, dataSchema, finalSchema,
+ readPartitionSchema(), pushedParquetFilters, options, pushedAggregations,
+ partitionFilters, dataFilters)
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/FileScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/FileScanSuite.scala
index d0877db..604a892 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/FileScanSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/FileScanSuite.scala
@@ -354,7 +354,7 @@ class FileScanSuite extends FileScanSuiteBase {
val scanBuilders = Seq[(String, ScanBuilder, Seq[String])](
("ParquetScan",
(s, fi, ds, rds, rps, f, o, pf, df) =>
- ParquetScan(s, s.sessionState.newHadoopConf(), fi, ds, rds, rps, f, o, pf, df),
+ ParquetScan(s, s.sessionState.newHadoopConf(), fi, ds, rds, rps, f, o, None, pf, df),
Seq.empty),
("OrcScan",
(s, fi, ds, rds, rps, f, o, pf, df) =>
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAggregatePushDownSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAggregatePushDownSuite.scala
new file mode 100644
index 0000000..c795bd9
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAggregatePushDownSuite.scala
@@ -0,0 +1,518 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.parquet
+
+import java.sql.{Date, Timestamp}
+
+import org.apache.spark.SparkConf
+import org.apache.spark.sql._
+import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanRelation
+import org.apache.spark.sql.functions.min
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.sql.types._
+
+/**
+ * A test suite that tests Max/Min/Count push down.
+ */
+abstract class ParquetAggregatePushDownSuite
+ extends QueryTest
+ with ParquetTest
+ with SharedSparkSession
+ with ExplainSuiteHelper {
+ import testImplicits._
+
+ test("aggregate push down - nested column: Max(top level column) not push down") {
+ val data = (1 to 10).map(i => Tuple1((i, Seq(s"val_$i"))))
+ withSQLConf(
+ SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true") {
+ withParquetTable(data, "t") {
+ val max = sql("SELECT Max(_1) FROM t")
+ max.queryExecution.optimizedPlan.collect {
+ case _: DataSourceV2ScanRelation =>
+ val expected_plan_fragment =
+ "PushedAggregation: []"
+ checkKeywordsExistsInExplain(max, expected_plan_fragment)
+ }
+ }
+ }
+ }
+
+ test("aggregate push down - nested column: Count(top level column) push down") {
+ val data = (1 to 10).map(i => Tuple1((i, Seq(s"val_$i"))))
+ withSQLConf(
+ SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true") {
+ withParquetTable(data, "t") {
+ val count = sql("SELECT Count(_1) FROM t")
+ count.queryExecution.optimizedPlan.collect {
+ case _: DataSourceV2ScanRelation =>
+ val expected_plan_fragment =
+ "PushedAggregation: [COUNT(_1)]"
+ checkKeywordsExistsInExplain(count, expected_plan_fragment)
+ }
+ checkAnswer(count, Seq(Row(10)))
+ }
+ }
+ }
+
+ test("aggregate push down - nested column: Max(nested column) not push down") {
+ val data = (1 to 10).map(i => Tuple1((i, Seq(s"val_$i"))))
+ withSQLConf(
+ SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true") {
+ withParquetTable(data, "t") {
+ val max = sql("SELECT Max(_1._2[0]) FROM t")
+ max.queryExecution.optimizedPlan.collect {
+ case _: DataSourceV2ScanRelation =>
+ val expected_plan_fragment =
+ "PushedAggregation: []"
+ checkKeywordsExistsInExplain(max, expected_plan_fragment)
+ }
+ }
+ }
+ }
+
+ test("aggregate push down - nested column: Count(nested column) not push down") {
+ val data = (1 to 10).map(i => Tuple1((i, Seq(s"val_$i"))))
+ withSQLConf(
+ SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true") {
+ withParquetTable(data, "t") {
+ val count = sql("SELECT Count(_1._2[0]) FROM t")
+ count.queryExecution.optimizedPlan.collect {
+ case _: DataSourceV2ScanRelation =>
+ val expected_plan_fragment =
+ "PushedAggregation: []"
+ checkKeywordsExistsInExplain(count, expected_plan_fragment)
+ }
+ checkAnswer(count, Seq(Row(10)))
+ }
+ }
+ }
+
+ test("aggregate push down - Max(partition Col): not push dow") {
+ withTempPath { dir =>
+ spark.range(10).selectExpr("id", "id % 3 as p")
+ .write.partitionBy("p").parquet(dir.getCanonicalPath)
+ withTempView("tmp") {
+ spark.read.parquet(dir.getCanonicalPath).createOrReplaceTempView("tmp");
+ withSQLConf(SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true") {
+ val max = sql("SELECT Max(p) FROM tmp")
+ max.queryExecution.optimizedPlan.collect {
+ case _: DataSourceV2ScanRelation =>
+ val expected_plan_fragment =
+ "PushedAggregation: []"
+ checkKeywordsExistsInExplain(max, expected_plan_fragment)
+ }
+ checkAnswer(max, Seq(Row(2)))
+ }
+ }
+ }
+ }
+
+ test("aggregate push down - Count(partition Col): push down") {
+ withTempPath { dir =>
+ spark.range(10).selectExpr("id", "id % 3 as p")
+ .write.partitionBy("p").parquet(dir.getCanonicalPath)
+ withTempView("tmp") {
+ spark.read.parquet(dir.getCanonicalPath).createOrReplaceTempView("tmp");
+ val enableVectorizedReader = Seq("false", "true")
+ for (testVectorizedReader <- enableVectorizedReader) {
+ withSQLConf(SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true",
+ vectorizedReaderEnabledKey -> testVectorizedReader) {
+ val count = sql("SELECT COUNT(p) FROM tmp")
+ count.queryExecution.optimizedPlan.collect {
+ case _: DataSourceV2ScanRelation =>
+ val expected_plan_fragment =
+ "PushedAggregation: [COUNT(p)]"
+ checkKeywordsExistsInExplain(count, expected_plan_fragment)
+ }
+ checkAnswer(count, Seq(Row(10)))
+ }
+ }
+ }
+ }
+ }
+
+ test("aggregate push down - Filter alias over aggregate") {
+ val data = Seq((-2, "abc", 2), (3, "def", 4), (6, "ghi", 2), (0, null, 19),
+ (9, "mno", 7), (2, null, 6))
+ withParquetTable(data, "t") {
+ withSQLConf(
+ SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true") {
+ val selectAgg = sql("SELECT min(_1) + max(_1) as res FROM t having res > 1")
+ selectAgg.queryExecution.optimizedPlan.collect {
+ case _: DataSourceV2ScanRelation =>
+ val expected_plan_fragment =
+ "PushedAggregation: [MIN(_1), MAX(_1)]"
+ checkKeywordsExistsInExplain(selectAgg, expected_plan_fragment)
+ }
+ checkAnswer(selectAgg, Seq(Row(7)))
+ }
+ }
+ }
+
+ test("aggregate push down - alias over aggregate") {
+ val data = Seq((-2, "abc", 2), (3, "def", 4), (6, "ghi", 2), (0, null, 19),
+ (9, "mno", 7), (2, null, 6))
+ withParquetTable(data, "t") {
+ withSQLConf(
+ SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true") {
+ val selectAgg = sql("SELECT min(_1) + 1 as minPlus1, min(_1) + 2 as minPlus2 FROM t")
+ selectAgg.queryExecution.optimizedPlan.collect {
+ case _: DataSourceV2ScanRelation =>
+ val expected_plan_fragment =
+ "PushedAggregation: [MIN(_1)]"
+ checkKeywordsExistsInExplain(selectAgg, expected_plan_fragment)
+ }
+ checkAnswer(selectAgg, Seq(Row(-1, 0)))
+ }
+ }
+ }
+
+ test("aggregate push down - aggregate over alias not push down") {
+ val data = Seq((-2, "abc", 2), (3, "def", 4), (6, "ghi", 2), (0, null, 19),
+ (9, "mno", 7), (2, null, 6))
+ withParquetTable(data, "t") {
+ withSQLConf(
+ SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true") {
+ val df = spark.table("t")
+ val query = df.select($"_1".as("col1")).agg(min($"col1"))
+ query.queryExecution.optimizedPlan.collect {
+ case _: DataSourceV2ScanRelation =>
+ val expected_plan_fragment =
+ "PushedAggregation: []" // aggregate alias not pushed down
+ checkKeywordsExistsInExplain(query, expected_plan_fragment)
+ }
+ checkAnswer(query, Seq(Row(-2)))
+ }
+ }
+ }
+
+ test("aggregate push down - query with group by not push down") {
+ val data = Seq((-2, "abc", 2), (3, "def", 4), (6, "ghi", 2), (0, null, 19),
+ (9, "mno", 7), (2, null, 7))
+ withParquetTable(data, "t") {
+ withSQLConf(
+ SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true") {
+ // aggregate not pushed down if there is group by
+ val selectAgg = sql("SELECT min(_1) FROM t GROUP BY _3 ")
+ selectAgg.queryExecution.optimizedPlan.collect {
+ case _: DataSourceV2ScanRelation =>
+ val expected_plan_fragment =
+ "PushedAggregation: []"
+ checkKeywordsExistsInExplain(selectAgg, expected_plan_fragment)
+ }
+ checkAnswer(selectAgg, Seq(Row(-2), Row(0), Row(2), Row(3)))
+ }
+ }
+ }
+
+ test("aggregate push down - query with filter not push down") {
+ val data = Seq((-2, "abc", 2), (3, "def", 4), (6, "ghi", 2), (0, null, 19),
+ (9, "mno", 7), (2, null, 7))
+ withParquetTable(data, "t") {
+ withSQLConf(
+ SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true") {
+ // aggregate not pushed down if there is filter
+ val selectAgg = sql("SELECT min(_3) FROM t WHERE _1 > 0")
+ selectAgg.queryExecution.optimizedPlan.collect {
+ case _: DataSourceV2ScanRelation =>
+ val expected_plan_fragment =
+ "PushedAggregation: []"
+ checkKeywordsExistsInExplain(selectAgg, expected_plan_fragment)
+ }
+ checkAnswer(selectAgg, Seq(Row(2)))
+ }
+ }
+ }
+
+ test("aggregate push down - push down only if all the aggregates can be pushed down") {
+ val data = Seq((-2, "abc", 2), (3, "def", 4), (6, "ghi", 2), (0, null, 19),
+ (9, "mno", 7), (2, null, 7))
+ withParquetTable(data, "t") {
+ withSQLConf(
+ SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true") {
+ // not push down since sum can't be pushed down
+ val selectAgg = sql("SELECT min(_1), sum(_3) FROM t")
+ selectAgg.queryExecution.optimizedPlan.collect {
+ case _: DataSourceV2ScanRelation =>
+ val expected_plan_fragment =
+ "PushedAggregation: []"
+ checkKeywordsExistsInExplain(selectAgg, expected_plan_fragment)
+ }
+ checkAnswer(selectAgg, Seq(Row(-2, 41)))
+ }
+ }
+ }
+
+ test("aggregate push down - MIN/MAX/COUNT") {
+ val data = Seq((-2, "abc", 2), (3, "def", 4), (6, "ghi", 2), (0, null, 19),
+ (9, "mno", 7), (2, null, 6))
+ withParquetTable(data, "t") {
+ withSQLConf(
+ SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true") {
+ val selectAgg = sql("SELECT min(_3), min(_3), max(_3), min(_1), max(_1), max(_1)," +
+ " count(*), count(_1), count(_2), count(_3) FROM t")
+ selectAgg.queryExecution.optimizedPlan.collect {
+ case _: DataSourceV2ScanRelation =>
+ val expected_plan_fragment =
+ "PushedAggregation: [MIN(_3), " +
+ "MAX(_3), " +
+ "MIN(_1), " +
+ "MAX(_1), " +
+ "COUNT(*), " +
+ "COUNT(_1), " +
+ "COUNT(_2), " +
+ "COUNT(_3)]"
+ checkKeywordsExistsInExplain(selectAgg, expected_plan_fragment)
+ }
+
+ checkAnswer(selectAgg, Seq(Row(2, 2, 19, -2, 9, 9, 6, 6, 4, 6)))
+ }
+ }
+ }
+
+ test("aggregate push down - different data types") {
+ implicit class StringToDate(s: String) {
+ def date: Date = Date.valueOf(s)
+ }
+
+ implicit class StringToTs(s: String) {
+ def ts: Timestamp = Timestamp.valueOf(s)
+ }
+
+ val rows =
+ Seq(
+ Row(
+ "a string",
+ true,
+ 10.toByte,
+ "Spark SQL".getBytes,
+ 12.toShort,
+ 3,
+ Long.MaxValue,
+ 0.15.toFloat,
+ 0.75D,
+ Decimal("12.345678"),
+ ("2021-01-01").date,
+ ("2015-01-01 23:50:59.123").ts),
+ Row(
+ "test string",
+ false,
+ 1.toByte,
+ "Parquet".getBytes,
+ 2.toShort,
+ null,
+ Long.MinValue,
+ 0.25.toFloat,
+ 0.85D,
+ Decimal("1.2345678"),
+ ("2015-01-01").date,
+ ("2021-01-01 23:50:59.123").ts),
+ Row(
+ null,
+ true,
+ 10000.toByte,
+ "Spark ML".getBytes,
+ 222.toShort,
+ 113,
+ 11111111L,
+ 0.25.toFloat,
+ 0.75D,
+ Decimal("12345.678"),
+ ("2004-06-19").date,
+ ("1999-08-26 10:43:59.123").ts)
+ )
+
+ val schema = StructType(List(StructField("StringCol", StringType, true),
+ StructField("BooleanCol", BooleanType, false),
+ StructField("ByteCol", ByteType, false),
+ StructField("BinaryCol", BinaryType, false),
+ StructField("ShortCol", ShortType, false),
+ StructField("IntegerCol", IntegerType, true),
+ StructField("LongCol", LongType, false),
+ StructField("FloatCol", FloatType, false),
+ StructField("DoubleCol", DoubleType, false),
+ StructField("DecimalCol", DecimalType(25, 5), true),
+ StructField("DateCol", DateType, false),
+ StructField("TimestampCol", TimestampType, false)).toArray)
+
+ val rdd = sparkContext.parallelize(rows)
+ withTempPath { file =>
+ spark.createDataFrame(rdd, schema).write.parquet(file.getCanonicalPath)
+ withTempView("test") {
+ spark.read.parquet(file.getCanonicalPath).createOrReplaceTempView("test")
+ val enableVectorizedReader = Seq("false", "true")
+ for (testVectorizedReader <- enableVectorizedReader) {
+ withSQLConf(SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true",
+ vectorizedReaderEnabledKey -> testVectorizedReader) {
+
+ val testMinWithTS = sql("SELECT min(StringCol), min(BooleanCol), min(ByteCol), " +
+ "min(BinaryCol), min(ShortCol), min(IntegerCol), min(LongCol), min(FloatCol), " +
+ "min(DoubleCol), min(DecimalCol), min(DateCol), min(TimestampCol) FROM test")
+
+ // INT96 (Timestamp) sort order is undefined, parquet doesn't return stats for this type
+ // so aggregates are not pushed down
+ testMinWithTS.queryExecution.optimizedPlan.collect {
+ case _: DataSourceV2ScanRelation =>
+ val expected_plan_fragment =
+ "PushedAggregation: []"
+ checkKeywordsExistsInExplain(testMinWithTS, expected_plan_fragment)
+ }
+
+ checkAnswer(testMinWithTS, Seq(Row("a string", false, 1.toByte, "Parquet".getBytes,
+ 2.toShort, 3, -9223372036854775808L, 0.15.toFloat, 0.75D, 1.23457,
+ ("2004-06-19").date, ("1999-08-26 10:43:59.123").ts)))
+
+ val testMinWithOutTS = sql("SELECT min(StringCol), min(BooleanCol), min(ByteCol), " +
+ "min(BinaryCol), min(ShortCol), min(IntegerCol), min(LongCol), min(FloatCol), " +
+ "min(DoubleCol), min(DecimalCol), min(DateCol) FROM test")
+
+ testMinWithOutTS.queryExecution.optimizedPlan.collect {
+ case _: DataSourceV2ScanRelation =>
+ val expected_plan_fragment =
+ "PushedAggregation: [MIN(StringCol), " +
+ "MIN(BooleanCol), " +
+ "MIN(ByteCol), " +
+ "MIN(BinaryCol), " +
+ "MIN(ShortCol), " +
+ "MIN(IntegerCol), " +
+ "MIN(LongCol), " +
+ "MIN(FloatCol), " +
+ "MIN(DoubleCol), " +
+ "MIN(DecimalCol), " +
+ "MIN(DateCol)]"
+ checkKeywordsExistsInExplain(testMinWithOutTS, expected_plan_fragment)
+ }
+
+ checkAnswer(testMinWithOutTS, Seq(Row("a string", false, 1.toByte, "Parquet".getBytes,
+ 2.toShort, 3, -9223372036854775808L, 0.15.toFloat, 0.75D, 1.23457,
+ ("2004-06-19").date)))
+
+ val testMaxWithTS = sql("SELECT max(StringCol), max(BooleanCol), max(ByteCol), " +
+ "max(BinaryCol), max(ShortCol), max(IntegerCol), max(LongCol), max(FloatCol), " +
+ "max(DoubleCol), max(DecimalCol), max(DateCol), max(TimestampCol) FROM test")
+
+ // INT96 (Timestamp) sort order is undefined, parquet doesn't return stats for this type
+ // so aggregates are not pushed down
+ testMaxWithTS.queryExecution.optimizedPlan.collect {
+ case _: DataSourceV2ScanRelation =>
+ val expected_plan_fragment =
+ "PushedAggregation: []"
+ checkKeywordsExistsInExplain(testMaxWithTS, expected_plan_fragment)
+ }
+
+ checkAnswer(testMaxWithTS, Seq(Row("test string", true, 16.toByte,
+ "Spark SQL".getBytes, 222.toShort, 113, 9223372036854775807L, 0.25.toFloat, 0.85D,
+ 12345.678, ("2021-01-01").date, ("2021-01-01 23:50:59.123").ts)))
+
+ val testMaxWithoutTS = sql("SELECT max(StringCol), max(BooleanCol), max(ByteCol), " +
+ "max(BinaryCol), max(ShortCol), max(IntegerCol), max(LongCol), max(FloatCol), " +
+ "max(DoubleCol), max(DecimalCol), max(DateCol) FROM test")
+
+ testMaxWithoutTS.queryExecution.optimizedPlan.collect {
+ case _: DataSourceV2ScanRelation =>
+ val expected_plan_fragment =
+ "PushedAggregation: [MAX(StringCol), " +
+ "MAX(BooleanCol), " +
+ "MAX(ByteCol), " +
+ "MAX(BinaryCol), " +
+ "MAX(ShortCol), " +
+ "MAX(IntegerCol), " +
+ "MAX(LongCol), " +
+ "MAX(FloatCol), " +
+ "MAX(DoubleCol), " +
+ "MAX(DecimalCol), " +
+ "MAX(DateCol)]"
+ checkKeywordsExistsInExplain(testMaxWithoutTS, expected_plan_fragment)
+ }
+
+ checkAnswer(testMaxWithoutTS, Seq(Row("test string", true, 16.toByte,
+ "Spark SQL".getBytes, 222.toShort, 113, 9223372036854775807L, 0.25.toFloat, 0.85D,
+ 12345.678, ("2021-01-01").date)))
+
+ val testCount = sql("SELECT count(StringCol), count(BooleanCol)," +
+ " count(ByteCol), count(BinaryCol), count(ShortCol), count(IntegerCol)," +
+ " count(LongCol), count(FloatCol), count(DoubleCol)," +
+ " count(DecimalCol), count(DateCol), count(TimestampCol) FROM test")
+
+ testCount.queryExecution.optimizedPlan.collect {
+ case _: DataSourceV2ScanRelation =>
+ val expected_plan_fragment =
+ "PushedAggregation: [" +
+ "COUNT(StringCol), " +
+ "COUNT(BooleanCol), " +
+ "COUNT(ByteCol), " +
+ "COUNT(BinaryCol), " +
+ "COUNT(ShortCol), " +
+ "COUNT(IntegerCol), " +
+ "COUNT(LongCol), " +
+ "COUNT(FloatCol), " +
+ "COUNT(DoubleCol), " +
+ "COUNT(DecimalCol), " +
+ "COUNT(DateCol), " +
+ "COUNT(TimestampCol)]"
+ checkKeywordsExistsInExplain(testCount, expected_plan_fragment)
+ }
+
+ checkAnswer(testCount, Seq(Row(2, 3, 3, 3, 3, 2, 3, 3, 3, 3, 3, 3)))
+ }
+ }
+ }
+ }
+ }
+
+ test("aggregate push down - column name case sensitivity") {
+ val enableVectorizedReader = Seq("false", "true")
+ for (testVectorizedReader <- enableVectorizedReader) {
+ withSQLConf(SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true",
+ vectorizedReaderEnabledKey -> testVectorizedReader) {
+ withTempPath { dir =>
+ spark.range(10).selectExpr("id", "id % 3 as p")
+ .write.partitionBy("p").parquet(dir.getCanonicalPath)
+ withTempView("tmp") {
+ spark.read.parquet(dir.getCanonicalPath).createOrReplaceTempView("tmp");
+ val selectAgg = sql("SELECT max(iD), min(Id) FROM tmp")
+ selectAgg.queryExecution.optimizedPlan.collect {
+ case _: DataSourceV2ScanRelation =>
+ val expected_plan_fragment =
+ "PushedAggregation: [MAX(id), MIN(id)]"
+ checkKeywordsExistsInExplain(selectAgg, expected_plan_fragment)
+ }
+ checkAnswer(selectAgg, Seq(Row(9, 0)))
+ }
+ }
+ }
+ }
+ }
+}
+
+class ParquetV1AggregatePushDownSuite extends ParquetAggregatePushDownSuite {
+
+ override protected def sparkConf: SparkConf =
+ super
+ .sparkConf
+ .set(SQLConf.USE_V1_SOURCE_LIST, "parquet")
+}
+
+class ParquetV2AggregatePushDownSuite extends ParquetAggregatePushDownSuite {
+
+ override protected def sparkConf: SparkConf =
+ super
+ .sparkConf
+ .set(SQLConf.USE_V1_SOURCE_LIST, "")
+}