You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by vi...@apache.org on 2021/10/11 05:20:59 UTC

[spark] branch master updated: [SPARK-36645][SQL] Aggregate (Min/Max/Count) push down for Parquet

This is an automated email from the ASF dual-hosted git repository.

viirya pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 128168d  [SPARK-36645][SQL] Aggregate (Min/Max/Count) push down for Parquet
128168d is described below

commit 128168d8c4019a1e10a9f1be734868524f6a09f0
Author: Huaxin Gao <hu...@apple.com>
AuthorDate: Sun Oct 10 22:20:09 2021 -0700

    [SPARK-36645][SQL] Aggregate (Min/Max/Count) push down for Parquet
    
    ### What changes were proposed in this pull request?
    Push down Min/Max/Count to Parquet with the following restrictions:
    
    - nested types such as Array, Map or Struct will not be pushed down
    - Timestamp is not pushed down because the INT96 sort order is undefined and Parquet doesn't return statistics for INT96
    - If the aggregate column is a partition column, only Count will be pushed down; Min and Max will not be pushed down because Parquet doesn't return max/min for partition columns.
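    - If the file doesn't have statistics for the aggregate columns, Spark will throw an exception (an `UnsupportedOperationException` advising the user to set the new SQLConf to false and execute the query again).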
    - Currently, if a filter or GROUP BY is involved, Min/Max/Count will not be pushed down. This restriction is planned to be lifted for the case where the filter or GROUP BY is on a partition column (https://issues.apache.org/jira/browse/SPARK-36646 and https://issues.apache.org/jira/browse/SPARK-36647)
    
    ### Why are the changes needed?
    Since Parquet stores statistics for min, max and count, we want to take advantage of this information and push down Min/Max/Count to the Parquet layer for better performance.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes, `SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED` (`spark.sql.parquet.aggregatePushdown`, default `false`) was added. If it is set to true, Min/Max/Count will be pushed down to Parquet.
    
    ### How was this patch tested?
    Added a new test suite, `ParquetAggregatePushDownSuite`.
    
    Closes #33639 from huaxingao/parquet_agg.
    
    Authored-by: Huaxin Gao <hu...@apple.com>
    Signed-off-by: Liang-Chi Hsieh <vi...@gmail.com>
---
 .../org/apache/spark/sql/internal/SQLConf.scala    |  10 +
 .../org/apache/spark/sql/types/StructType.scala    |   2 +-
 .../datasources/parquet/ParquetUtils.scala         | 227 +++++++++
 .../execution/datasources/v2/FileScanBuilder.scala |   2 +-
 .../v2/parquet/ParquetPartitionReaderFactory.scala | 123 ++++-
 .../datasources/v2/parquet/ParquetScan.scala       |  37 +-
 .../v2/parquet/ParquetScanBuilder.scala            |  96 +++-
 .../scala/org/apache/spark/sql/FileScanSuite.scala |   2 +-
 .../parquet/ParquetAggregatePushDownSuite.scala    | 518 +++++++++++++++++++++
 9 files changed, 984 insertions(+), 33 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 6443dfd..98aad1c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -853,6 +853,14 @@ object SQLConf {
       .checkValue(threshold => threshold >= 0, "The threshold must not be negative.")
       .createWithDefault(10)
 
+  val PARQUET_AGGREGATE_PUSHDOWN_ENABLED = buildConf("spark.sql.parquet.aggregatePushdown")
+    .doc("If true, MAX/MIN/COUNT without filter and group by will be pushed" +
+      " down to Parquet for optimization. MAX/MIN/COUNT for complex types and timestamp" +
+      " can't be pushed down")
+    .version("3.3.0")
+    .booleanConf
+    .createWithDefault(false)
+
   val PARQUET_WRITE_LEGACY_FORMAT = buildConf("spark.sql.parquet.writeLegacyFormat")
     .doc("If true, data will be written in a way of Spark 1.4 and earlier. For example, decimal " +
       "values will be written in Apache Parquet's fixed-length byte array format, which other " +
@@ -3660,6 +3668,8 @@ class SQLConf extends Serializable with Logging {
   def parquetFilterPushDownInFilterThreshold: Int =
     getConf(PARQUET_FILTER_PUSHDOWN_INFILTERTHRESHOLD)
 
+  def parquetAggregatePushDown: Boolean = getConf(PARQUET_AGGREGATE_PUSHDOWN_ENABLED)
+
   def orcFilterPushDown: Boolean = getConf(ORC_FILTER_PUSHDOWN_ENABLED)
 
   def isOrcSchemaMergingEnabled: Boolean = getConf(ORC_SCHEMA_MERGING_ENABLED)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala
index c9862cb..50b197f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala
@@ -115,7 +115,7 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru
   def names: Array[String] = fieldNames
 
   private lazy val fieldNamesSet: Set[String] = fieldNames.toSet
-  private lazy val nameToField: Map[String, StructField] = fields.map(f => f.name -> f).toMap
+  private[sql] lazy val nameToField: Map[String, StructField] = fields.map(f => f.name -> f).toMap
   private lazy val nameToIndex: Map[String, Int] = fieldNames.zipWithIndex.toMap
 
   override def equals(that: Any): Boolean = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala
index b91d75c..1093f9c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala
@@ -16,11 +16,28 @@
  */
 package org.apache.spark.sql.execution.datasources.parquet
 
+import java.util
+
+import scala.collection.mutable
+import scala.language.existentials
+
 import org.apache.hadoop.fs.{FileStatus, Path}
 import org.apache.parquet.hadoop.ParquetFileWriter
+import org.apache.parquet.hadoop.metadata.{ColumnChunkMetaData, ParquetMetadata}
+import org.apache.parquet.io.api.Binary
+import org.apache.parquet.schema.{PrimitiveType, Types}
+import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName
 
+import org.apache.spark.SparkException
 import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.connector.expressions.aggregate.{Aggregation, Count, CountStar, Max, Min}
+import org.apache.spark.sql.execution.RowToColumnConverter
+import org.apache.spark.sql.execution.datasources.PartitioningUtils
+import org.apache.spark.sql.execution.vectorized.{OffHeapColumnVector, OnHeapColumnVector}
+import org.apache.spark.sql.internal.SQLConf.{LegacyBehaviorPolicy, PARQUET_AGGREGATE_PUSHDOWN_ENABLED}
 import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector}
 
 object ParquetUtils {
   def inferSchema(
@@ -127,4 +144,214 @@ object ParquetUtils {
     file.getName == ParquetFileWriter.PARQUET_COMMON_METADATA_FILE ||
       file.getName == ParquetFileWriter.PARQUET_METADATA_FILE
   }
+
+  /**
+   * When the partial aggregates (Max/Min/Count) are pushed down to Parquet, we don't need to
+   * call createRowBaseReader to read data from Parquet and aggregate at the Spark layer. Instead,
+   * we compute the partial aggregate (Max/Min/Count) results from the statistics stored in the
+   * Parquet file footer, and then construct an InternalRow from these aggregate results.
+   *
+   * @return Aggregate results in the format of InternalRow
+   */
+  private[sql] def createAggInternalRowFromFooter(
+      footer: ParquetMetadata,
+      filePath: String,
+      dataSchema: StructType,
+      partitionSchema: StructType,
+      aggregation: Aggregation,
+      aggSchema: StructType,
+      datetimeRebaseMode: LegacyBehaviorPolicy.Value,
+      isCaseSensitive: Boolean): InternalRow = {
+    val (primitiveTypes, values) = getPushedDownAggResult(
+      footer, filePath, dataSchema, partitionSchema, aggregation, isCaseSensitive)
+
+    val builder = Types.buildMessage
+    primitiveTypes.foreach(t => builder.addField(t))
+    val parquetSchema = builder.named("root")
+
+    val schemaConverter = new ParquetToSparkSchemaConverter
+    val converter = new ParquetRowConverter(schemaConverter, parquetSchema, aggSchema,
+      None, datetimeRebaseMode, LegacyBehaviorPolicy.CORRECTED, NoopUpdater)
+    val primitiveTypeNames = primitiveTypes.map(_.getPrimitiveTypeName)
+    primitiveTypeNames.zipWithIndex.foreach {
+      case (PrimitiveType.PrimitiveTypeName.BOOLEAN, i) =>
+        val v = values(i).asInstanceOf[Boolean]
+        converter.getConverter(i).asPrimitiveConverter.addBoolean(v)
+      case (PrimitiveType.PrimitiveTypeName.INT32, i) =>
+        val v = values(i).asInstanceOf[Integer]
+        converter.getConverter(i).asPrimitiveConverter.addInt(v)
+      case (PrimitiveType.PrimitiveTypeName.INT64, i) =>
+        val v = values(i).asInstanceOf[Long]
+        converter.getConverter(i).asPrimitiveConverter.addLong(v)
+      case (PrimitiveType.PrimitiveTypeName.FLOAT, i) =>
+        val v = values(i).asInstanceOf[Float]
+        converter.getConverter(i).asPrimitiveConverter.addFloat(v)
+      case (PrimitiveType.PrimitiveTypeName.DOUBLE, i) =>
+        val v = values(i).asInstanceOf[Double]
+        converter.getConverter(i).asPrimitiveConverter.addDouble(v)
+      case (PrimitiveType.PrimitiveTypeName.BINARY, i) =>
+        val v = values(i).asInstanceOf[Binary]
+        converter.getConverter(i).asPrimitiveConverter.addBinary(v)
+      case (PrimitiveType.PrimitiveTypeName.FIXED_LEN_BYTE_ARRAY, i) =>
+        val v = values(i).asInstanceOf[Binary]
+        converter.getConverter(i).asPrimitiveConverter.addBinary(v)
+      case (_, i) =>
+        throw new SparkException("Unexpected parquet type name: " + primitiveTypeNames(i))
+    }
+    converter.currentRecord
+  }
+
+  /**
+   * When the aggregates (Max/Min/Count) are pushed down to Parquet and
+   * PARQUET_VECTORIZED_READER_ENABLED is set to true, we don't need buildColumnarReader
+   * to read data from Parquet and aggregate at the Spark layer. Instead, we compute
+   * the aggregate (Max/Min/Count) results from the statistics stored in the
+   * Parquet file footer, and then construct a ColumnarBatch from these aggregate results.
+   *
+   * @return Aggregate results in the format of ColumnarBatch
+   */
+  private[sql] def createAggColumnarBatchFromFooter(
+      footer: ParquetMetadata,
+      filePath: String,
+      dataSchema: StructType,
+      partitionSchema: StructType,
+      aggregation: Aggregation,
+      aggSchema: StructType,
+      offHeap: Boolean,
+      datetimeRebaseMode: LegacyBehaviorPolicy.Value,
+      isCaseSensitive: Boolean): ColumnarBatch = {
+    val row = createAggInternalRowFromFooter(
+      footer,
+      filePath,
+      dataSchema,
+      partitionSchema,
+      aggregation,
+      aggSchema,
+      datetimeRebaseMode,
+      isCaseSensitive)
+    val converter = new RowToColumnConverter(aggSchema)
+    val columnVectors = if (offHeap) {
+      OffHeapColumnVector.allocateColumns(1, aggSchema)
+    } else {
+      OnHeapColumnVector.allocateColumns(1, aggSchema)
+    }
+    converter.convert(row, columnVectors.toArray)
+    new ColumnarBatch(columnVectors.asInstanceOf[Array[ColumnVector]], 1)
+  }
+
+  /**
+   * Calculate the pushed down aggregates (Max/Min/Count) result using the statistics
+   * information from Parquet footer file.
+   *
+   * @return A tuple of `Array[PrimitiveType]` and Array[Any].
+   *         The first element is the Parquet PrimitiveType of the aggregate column,
+   *         and the second element is the aggregated value.
+   */
+  private[sql] def getPushedDownAggResult(
+      footer: ParquetMetadata,
+      filePath: String,
+      dataSchema: StructType,
+      partitionSchema: StructType,
+      aggregation: Aggregation,
+      isCaseSensitive: Boolean)
+  : (Array[PrimitiveType], Array[Any]) = {
+    val footerFileMetaData = footer.getFileMetaData
+    val fields = footerFileMetaData.getSchema.getFields
+    val blocks = footer.getBlocks
+    val primitiveTypeBuilder = mutable.ArrayBuilder.make[PrimitiveType]
+    val valuesBuilder = mutable.ArrayBuilder.make[Any]
+
+    assert(aggregation.groupByColumns.length == 0, "group by shouldn't be pushed down")
+    aggregation.aggregateExpressions.foreach { agg =>
+      var value: Any = None
+      var rowCount = 0L
+      var isCount = false
+      var index = 0
+      var schemaName = ""
+      blocks.forEach { block =>
+        val blockMetaData = block.getColumns
+        agg match {
+          case max: Max =>
+            val colName = max.column.fieldNames.head
+            index = dataSchema.fieldNames.toList.indexOf(colName)
+            schemaName = "max(" + colName + ")"
+            val currentMax = getCurrentBlockMaxOrMin(filePath, blockMetaData, index, true)
+            if (value == None || currentMax.asInstanceOf[Comparable[Any]].compareTo(value) > 0) {
+              value = currentMax
+            }
+          case min: Min =>
+            val colName = min.column.fieldNames.head
+            index = dataSchema.fieldNames.toList.indexOf(colName)
+            schemaName = "min(" + colName + ")"
+            val currentMin = getCurrentBlockMaxOrMin(filePath, blockMetaData, index, false)
+            if (value == None || currentMin.asInstanceOf[Comparable[Any]].compareTo(value) < 0) {
+              value = currentMin
+            }
+          case count: Count =>
+            schemaName = "count(" + count.column.fieldNames.head + ")"
+            rowCount += block.getRowCount
+            var isPartitionCol = false
+            if (partitionSchema.fields.map(PartitioningUtils.getColName(_, isCaseSensitive))
+              .toSet.contains(count.column.fieldNames.head)) {
+              isPartitionCol = true
+            }
+            isCount = true
+            if (!isPartitionCol) {
+              index = dataSchema.fieldNames.toList.indexOf(count.column.fieldNames.head)
+              // Count(*) includes the null values, but Count(colName) doesn't.
+              rowCount -= getNumNulls(filePath, blockMetaData, index)
+            }
+          case _: CountStar =>
+            schemaName = "count(*)"
+            rowCount += block.getRowCount
+            isCount = true
+          case _ =>
+        }
+      }
+      if (isCount) {
+        valuesBuilder += rowCount
+        primitiveTypeBuilder += Types.required(PrimitiveTypeName.INT64).named(schemaName);
+      } else {
+        valuesBuilder += value
+        val field = fields.get(index)
+        primitiveTypeBuilder += Types.required(field.asPrimitiveType.getPrimitiveTypeName)
+          .as(field.getLogicalTypeAnnotation)
+          .length(field.asPrimitiveType.getTypeLength)
+          .named(schemaName)
+      }
+    }
+    (primitiveTypeBuilder.result, valuesBuilder.result)
+  }
+
+  /**
+   * Get the Max or Min value for ith column in the current block
+   *
+   * @return the Max or Min value
+   */
+  private def getCurrentBlockMaxOrMin(
+      filePath: String,
+      columnChunkMetaData: util.List[ColumnChunkMetaData],
+      i: Int,
+      isMax: Boolean): Any = {
+    val statistics = columnChunkMetaData.get(i).getStatistics
+    if (!statistics.hasNonNullValue) {
+      throw new UnsupportedOperationException(s"No min/max found for Parquet file $filePath. " +
+        s"Set SQLConf ${PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key} to false and execute again")
+    } else {
+      if (isMax) statistics.genericGetMax else statistics.genericGetMin
+    }
+  }
+
+  private def getNumNulls(
+      filePath: String,
+      columnChunkMetaData: util.List[ColumnChunkMetaData],
+      i: Int): Long = {
+    val statistics = columnChunkMetaData.get(i).getStatistics
+    if (!statistics.isNumNullsSet) {
+      throw new UnsupportedOperationException(s"Number of nulls not set for Parquet file" +
+        s" $filePath. Set SQLConf ${PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key} to false and execute" +
+        s" again")
+    }
+    statistics.getNumNulls;
+  }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScanBuilder.scala
index 309f045..2dc4137 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScanBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScanBuilder.scala
@@ -96,6 +96,6 @@ abstract class FileScanBuilder(
   private def createRequiredNameSet(): Set[String] =
     requiredSchema.fields.map(PartitioningUtils.getColName(_, isCaseSensitive)).toSet
 
-  private val partitionNameSet: Set[String] =
+  val partitionNameSet: Set[String] =
     partitionSchema.fields.map(PartitioningUtils.getColName(_, isCaseSensitive)).toSet
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala
index 058669b..111018b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala
@@ -25,14 +25,16 @@ import org.apache.hadoop.mapreduce._
 import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl
 import org.apache.parquet.filter2.compat.FilterCompat
 import org.apache.parquet.filter2.predicate.{FilterApi, FilterPredicate}
-import org.apache.parquet.format.converter.ParquetMetadataConverter.SKIP_ROW_GROUPS
+import org.apache.parquet.format.converter.ParquetMetadataConverter.{NO_FILTER, SKIP_ROW_GROUPS}
 import org.apache.parquet.hadoop.{ParquetInputFormat, ParquetRecordReader}
+import org.apache.parquet.hadoop.metadata.{FileMetaData, ParquetMetadata}
 
 import org.apache.spark.TaskContext
 import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
+import org.apache.spark.sql.connector.expressions.aggregate.Aggregation
 import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader}
 import org.apache.spark.sql.execution.datasources.{DataSourceUtils, PartitionedFile, RecordReaderIterator}
 import org.apache.spark.sql.execution.datasources.parquet._
@@ -53,6 +55,7 @@ import org.apache.spark.util.SerializableConfiguration
  * @param readDataSchema Required schema of Parquet files.
  * @param partitionSchema Schema of partitions.
  * @param filters Filters to be pushed down in the batch scan.
+ * @param aggregation Aggregation to be pushed down in the batch scan.
  * @param parquetOptions The options of Parquet datasource that are set for the read.
  */
 case class ParquetPartitionReaderFactory(
@@ -62,6 +65,7 @@ case class ParquetPartitionReaderFactory(
     readDataSchema: StructType,
     partitionSchema: StructType,
     filters: Array[Filter],
+    aggregation: Option[Aggregation],
     parquetOptions: ParquetOptions) extends FilePartitionReaderFactory with Logging {
   private val isCaseSensitive = sqlConf.caseSensitiveAnalysis
   private val resultSchema = StructType(partitionSchema.fields ++ readDataSchema.fields)
@@ -80,6 +84,30 @@ case class ParquetPartitionReaderFactory(
   private val datetimeRebaseModeInRead = parquetOptions.datetimeRebaseModeInRead
   private val int96RebaseModeInRead = parquetOptions.int96RebaseModeInRead
 
+  private def getFooter(file: PartitionedFile): ParquetMetadata = {
+    val conf = broadcastedConf.value.value
+    val filePath = new Path(new URI(file.filePath))
+
+    if (aggregation.isEmpty) {
+      ParquetFooterReader.readFooter(conf, filePath, SKIP_ROW_GROUPS)
+    } else {
+      // For aggregate push down, we will get max/min/count from footer statistics.
+      // We want to read the footer once for the whole file instead of reading it again
+      // for every split of the file. So we read the footer only for the split whose start
+      // offset in PartitionedFile is 0. A non-zero start offset means this is another split
+      // of the same file, so we skip reading the footer again and return null.
+      if (file.start != 0) return null
+      ParquetFooterReader.readFooter(conf, filePath, NO_FILTER)
+    }
+  }
+
+  private def getDatetimeRebaseMode(
+      footerFileMetaData: FileMetaData): LegacyBehaviorPolicy.Value = {
+    DataSourceUtils.datetimeRebaseMode(
+      footerFileMetaData.getKeyValueMetaData.get,
+      datetimeRebaseModeInRead)
+  }
+
   override def supportColumnarReads(partition: InputPartition): Boolean = {
     sqlConf.parquetVectorizedReaderEnabled && sqlConf.wholeStageEnabled &&
       resultSchema.length <= sqlConf.wholeStageMaxNumFields &&
@@ -87,18 +115,44 @@ case class ParquetPartitionReaderFactory(
   }
 
   override def buildReader(file: PartitionedFile): PartitionReader[InternalRow] = {
-    val reader = if (enableVectorizedReader) {
-      createVectorizedReader(file)
-    } else {
-      createRowBaseReader(file)
-    }
+    val fileReader = if (aggregation.isEmpty) {
+      val reader = if (enableVectorizedReader) {
+        createVectorizedReader(file)
+      } else {
+        createRowBaseReader(file)
+      }
+
+      new PartitionReader[InternalRow] {
+        override def next(): Boolean = reader.nextKeyValue()
 
-    val fileReader = new PartitionReader[InternalRow] {
-      override def next(): Boolean = reader.nextKeyValue()
+        override def get(): InternalRow = reader.getCurrentValue.asInstanceOf[InternalRow]
 
-      override def get(): InternalRow = reader.getCurrentValue.asInstanceOf[InternalRow]
+        override def close(): Unit = reader.close()
+      }
+    } else {
+      new PartitionReader[InternalRow] {
+        private var hasNext = true
+        private lazy val row: InternalRow = {
+          val footer = getFooter(file)
+          if (footer != null && footer.getBlocks.size > 0) {
+            ParquetUtils.createAggInternalRowFromFooter(footer, file.filePath, dataSchema,
+              partitionSchema, aggregation.get, readDataSchema,
+              getDatetimeRebaseMode(footer.getFileMetaData), isCaseSensitive)
+          } else {
+            null
+          }
+        }
+        override def next(): Boolean = {
+          hasNext && row != null
+        }
 
-      override def close(): Unit = reader.close()
+        override def get(): InternalRow = {
+          hasNext = false
+          row
+        }
+
+        override def close(): Unit = {}
+      }
     }
 
     new PartitionReaderWithPartitionValues(fileReader, readDataSchema,
@@ -106,17 +160,45 @@ case class ParquetPartitionReaderFactory(
   }
 
   override def buildColumnarReader(file: PartitionedFile): PartitionReader[ColumnarBatch] = {
-    val vectorizedReader = createVectorizedReader(file)
-    vectorizedReader.enableReturningBatches()
+    val fileReader = if (aggregation.isEmpty) {
+      val vectorizedReader = createVectorizedReader(file)
+      vectorizedReader.enableReturningBatches()
+
+      new PartitionReader[ColumnarBatch] {
+        override def next(): Boolean = vectorizedReader.nextKeyValue()
 
-    new PartitionReader[ColumnarBatch] {
-      override def next(): Boolean = vectorizedReader.nextKeyValue()
+        override def get(): ColumnarBatch =
+          vectorizedReader.getCurrentValue.asInstanceOf[ColumnarBatch]
 
-      override def get(): ColumnarBatch =
-        vectorizedReader.getCurrentValue.asInstanceOf[ColumnarBatch]
+        override def close(): Unit = vectorizedReader.close()
+      }
+    } else {
+      new PartitionReader[ColumnarBatch] {
+        private var hasNext = true
+        private val row: ColumnarBatch = {
+          val footer = getFooter(file)
+          if (footer != null && footer.getBlocks.size > 0) {
+            ParquetUtils.createAggColumnarBatchFromFooter(footer, file.filePath, dataSchema,
+              partitionSchema, aggregation.get, readDataSchema, enableOffHeapColumnVector,
+              getDatetimeRebaseMode(footer.getFileMetaData), isCaseSensitive)
+          } else {
+            null
+          }
+        }
+
+        override def next(): Boolean = {
+          hasNext && row != null
+        }
+
+        override def get(): ColumnarBatch = {
+          hasNext = false
+          row
+        }
 
-      override def close(): Unit = vectorizedReader.close()
+        override def close(): Unit = {}
+      }
     }
+    fileReader
   }
 
   private def buildReaderBase[T](
@@ -131,11 +213,8 @@ case class ParquetPartitionReaderFactory(
     val filePath = new Path(new URI(file.filePath))
     val split = new FileSplit(filePath, file.start, file.length, Array.empty[String])
 
-    lazy val footerFileMetaData =
-      ParquetFooterReader.readFooter(conf, filePath, SKIP_ROW_GROUPS).getFileMetaData
-    val datetimeRebaseMode = DataSourceUtils.datetimeRebaseMode(
-      footerFileMetaData.getKeyValueMetaData.get,
-      datetimeRebaseModeInRead)
+    lazy val footerFileMetaData = getFooter(file).getFileMetaData
+    val datetimeRebaseMode = getDatetimeRebaseMode(footerFileMetaData)
     // Try to push down filters when filter push-down is enabled.
     val pushed = if (enableParquetFilterPushDown) {
       val parquetSchema = footerFileMetaData.getSchema
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala
index e277e33..42dc287 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala
@@ -24,6 +24,7 @@ import org.apache.parquet.hadoop.ParquetInputFormat
 
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.connector.expressions.aggregate.Aggregation
 import org.apache.spark.sql.connector.read.PartitionReaderFactory
 import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex
 import org.apache.spark.sql.execution.datasources.parquet.{ParquetOptions, ParquetReadSupport, ParquetWriteSupport}
@@ -43,10 +44,17 @@ case class ParquetScan(
     readPartitionSchema: StructType,
     pushedFilters: Array[Filter],
     options: CaseInsensitiveStringMap,
+    pushedAggregate: Option[Aggregation] = None,
     partitionFilters: Seq[Expression] = Seq.empty,
     dataFilters: Seq[Expression] = Seq.empty) extends FileScan {
   override def isSplitable(path: Path): Boolean = true
 
+  override def readSchema(): StructType = {
+    // If aggregate is pushed down, schema has already been pruned in `ParquetScanBuilder`
+    // and no need to call super.readSchema()
+    if (pushedAggregate.nonEmpty) readDataSchema else super.readSchema()
+  }
+
   override def createReaderFactory(): PartitionReaderFactory = {
     val readDataSchemaAsJson = readDataSchema.json
     hadoopConf.set(ParquetInputFormat.READ_SUPPORT_CLASS, classOf[ParquetReadSupport].getName)
@@ -86,23 +94,46 @@ case class ParquetScan(
       readDataSchema,
       readPartitionSchema,
       pushedFilters,
+      pushedAggregate,
       new ParquetOptions(options.asCaseSensitiveMap.asScala.toMap, sqlConf))
   }
 
   override def equals(obj: Any): Boolean = obj match {
     case p: ParquetScan =>
+      val pushedDownAggEqual = if (pushedAggregate.nonEmpty && p.pushedAggregate.nonEmpty) {
+        equivalentAggregations(pushedAggregate.get, p.pushedAggregate.get)
+      } else {
+        pushedAggregate.isEmpty && p.pushedAggregate.isEmpty
+      }
       super.equals(p) && dataSchema == p.dataSchema && options == p.options &&
-        equivalentFilters(pushedFilters, p.pushedFilters)
+        equivalentFilters(pushedFilters, p.pushedFilters) && pushedDownAggEqual
     case _ => false
   }
 
   override def hashCode(): Int = getClass.hashCode()
 
+  lazy private val (pushedAggregationsStr, pushedGroupByStr) = if (pushedAggregate.nonEmpty) {
+    (seqToString(pushedAggregate.get.aggregateExpressions),
+      seqToString(pushedAggregate.get.groupByColumns))
+  } else {
+    ("[]", "[]")
+  }
+
   override def description(): String = {
-    super.description() + ", PushedFilters: " + seqToString(pushedFilters)
+    super.description() + ", PushedFilters: " + seqToString(pushedFilters) +
+      ", PushedAggregation: " + pushedAggregationsStr +
+      ", PushedGroupBy: " + pushedGroupByStr
   }
 
   override def getMetaData(): Map[String, String] = {
-    super.getMetaData() ++ Map("PushedFilters" -> seqToString(pushedFilters))
+    super.getMetaData() ++ Map("PushedFilters" -> seqToString(pushedFilters)) ++
+      Map("PushedAggregation" -> pushedAggregationsStr) ++
+      Map("PushedGroupBy" -> pushedGroupByStr)
+  }
+
+  private def equivalentAggregations(a: Aggregation, b: Aggregation): Boolean = {
+    a.aggregateExpressions.sortBy(_.hashCode())
+      .sameElements(b.aggregateExpressions.sortBy(_.hashCode())) &&
+      a.groupByColumns.sortBy(_.hashCode()).sameElements(b.groupByColumns.sortBy(_.hashCode()))
   }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala
index 9a0e4b4..c579867 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala
@@ -20,13 +20,15 @@ package org.apache.spark.sql.execution.datasources.v2.parquet
 import scala.collection.JavaConverters._
 
 import org.apache.spark.sql.SparkSession
-import org.apache.spark.sql.connector.read.Scan
+import org.apache.spark.sql.connector.expressions.NamedReference
+import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Aggregation, Count, CountStar, Max, Min}
+import org.apache.spark.sql.connector.read.{Scan, SupportsPushDownAggregates}
 import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex
 import org.apache.spark.sql.execution.datasources.parquet.{ParquetFilters, SparkToParquetSchemaConverter}
 import org.apache.spark.sql.execution.datasources.v2.FileScanBuilder
 import org.apache.spark.sql.internal.SQLConf.LegacyBehaviorPolicy
 import org.apache.spark.sql.sources.Filter
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.types.{ArrayType, LongType, MapType, StructField, StructType, TimestampType}
 import org.apache.spark.sql.util.CaseInsensitiveStringMap
 
 case class ParquetScanBuilder(
@@ -35,7 +37,8 @@ case class ParquetScanBuilder(
     schema: StructType,
     dataSchema: StructType,
     options: CaseInsensitiveStringMap)
-  extends FileScanBuilder(sparkSession, fileIndex, dataSchema) {
+  extends FileScanBuilder(sparkSession, fileIndex, dataSchema)
+    with SupportsPushDownAggregates{
   lazy val hadoopConf = {
     val caseSensitiveMap = options.asCaseSensitiveMap.asScala.toMap
     // Hadoop Configurations are case sensitive.
@@ -70,6 +73,10 @@ case class ParquetScanBuilder(
     }
   }
 
+  private var finalSchema = new StructType()
+
+  private var pushedAggregations = Option.empty[Aggregation]
+
   override protected val supportsNestedSchemaPruning: Boolean = true
 
   override def pushDataFilters(dataFilters: Array[Filter]): Array[Filter] = dataFilters
@@ -79,8 +86,87 @@ case class ParquetScanBuilder(
   // All filters that can be converted to Parquet are pushed down.
   override def pushedFilters(): Array[Filter] = pushedParquetFilters
 
+  override def pushAggregation(aggregation: Aggregation): Boolean = {
+
+    def getStructFieldForCol(col: NamedReference): StructField = {
+      schema.nameToField(col.fieldNames.head)
+    }
+
+    def isPartitionCol(col: NamedReference) = {
+      partitionNameSet.contains(col.fieldNames.head)
+    }
+
+    def processMinOrMax(agg: AggregateFunc): Boolean = {
+      val (column, aggType) = agg match {
+        case max: Max => (max.column, "max")
+        case min: Min => (min.column, "min")
+        case _ =>
+          throw new IllegalArgumentException(s"Unexpected type of AggregateFunc ${agg.describe}")
+      }
+
+      if (isPartitionCol(column)) {
+        // don't push down partition column, footer doesn't have max/min for partition column
+        return false
+      }
+      val structField = getStructFieldForCol(column)
+
+      structField.dataType match {
+        // Do not push down complex types.
+        // Do not push down Timestamp either: the INT96 sort order is undefined, so
+        // Parquet doesn't return statistics for INT96.
+        case StructType(_) | ArrayType(_, _) | MapType(_, _, _) | TimestampType =>
+          false
+        case _ =>
+          finalSchema = finalSchema.add(structField.copy(s"$aggType(" + structField.name + ")"))
+          true
+      }
+    }
+
+    if (!sparkSession.sessionState.conf.parquetAggregatePushDown ||
+      aggregation.groupByColumns.nonEmpty || dataFilters.length > 0) {
+      // The Parquet footer has max/min/count for columns,
+      // e.g. SELECT COUNT(col1) FROM t
+      // but the footer cannot provide max/min/count for a column when the
+      // aggregate is combined with a filter or a group by,
+      // e.g. SELECT COUNT(col1) FROM t WHERE col2 = 8
+      //      SELECT COUNT(col1) FROM t GROUP BY col2
+      // TODO: 1. add support if the group by column is a partition column
+      //          (https://issues.apache.org/jira/browse/SPARK-36646)
+      //       2. add support if the filter column is a partition column
+      //          (https://issues.apache.org/jira/browse/SPARK-36647)
+      return false
+    }
+
+    aggregation.groupByColumns.foreach { col =>
+      if (col.fieldNames.length != 1) return false
+      finalSchema = finalSchema.add(getStructFieldForCol(col))
+    }
+
+    aggregation.aggregateExpressions.foreach {
+      case max: Max =>
+        if (!processMinOrMax(max)) return false
+      case min: Min =>
+        if (!processMinOrMax(min)) return false
+      case count: Count =>
+        if (count.column.fieldNames.length != 1 || count.isDistinct) return false
+        finalSchema =
+          finalSchema.add(StructField(s"count(" + count.column.fieldNames.head + ")", LongType))
+      case _: CountStar =>
+        finalSchema = finalSchema.add(StructField("count(*)", LongType))
+      case _ =>
+        return false
+    }
+    this.pushedAggregations = Some(aggregation)
+    true
+  }
+
   override def build(): Scan = {
-    ParquetScan(sparkSession, hadoopConf, fileIndex, dataSchema, readDataSchema(),
-      readPartitionSchema(), pushedParquetFilters, options, partitionFilters, dataFilters)
+    // The `finalSchema` is either pruned in pushAggregation (if aggregates are
+    // pushed down), or pruned in readDataSchema() (in regular column pruning). These
+    // two are mutually exclusive.
+    if (pushedAggregations.isEmpty) finalSchema = readDataSchema()
+    ParquetScan(sparkSession, hadoopConf, fileIndex, dataSchema, finalSchema,
+      readPartitionSchema(), pushedParquetFilters, options, pushedAggregations,
+      partitionFilters, dataFilters)
   }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/FileScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/FileScanSuite.scala
index d0877db..604a892 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/FileScanSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/FileScanSuite.scala
@@ -354,7 +354,7 @@ class FileScanSuite extends FileScanSuiteBase {
   val scanBuilders = Seq[(String, ScanBuilder, Seq[String])](
     ("ParquetScan",
       (s, fi, ds, rds, rps, f, o, pf, df) =>
-        ParquetScan(s, s.sessionState.newHadoopConf(), fi, ds, rds, rps, f, o, pf, df),
+        ParquetScan(s, s.sessionState.newHadoopConf(), fi, ds, rds, rps, f, o, None, pf, df),
       Seq.empty),
     ("OrcScan",
       (s, fi, ds, rds, rps, f, o, pf, df) =>
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAggregatePushDownSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAggregatePushDownSuite.scala
new file mode 100644
index 0000000..c795bd9
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAggregatePushDownSuite.scala
@@ -0,0 +1,518 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.parquet
+
+import java.sql.{Date, Timestamp}
+
+import org.apache.spark.SparkConf
+import org.apache.spark.sql._
+import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanRelation
+import org.apache.spark.sql.functions.min
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.sql.types._
+
+/**
+ * A test suite that tests Max/Min/Count push down.
+ */
+abstract class ParquetAggregatePushDownSuite
+  extends QueryTest
+  with ParquetTest
+  with SharedSparkSession
+  with ExplainSuiteHelper {
+  import testImplicits._
+
+  test("aggregate push down - nested column: Max(top level column) not push down") {
+    val data = (1 to 10).map(i => Tuple1((i, Seq(s"val_$i"))))
+    withSQLConf(
+      SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true") {
+      withParquetTable(data, "t") {
+        val max = sql("SELECT Max(_1) FROM t")
+        max.queryExecution.optimizedPlan.collect {
+          case _: DataSourceV2ScanRelation =>
+            val expected_plan_fragment =
+              "PushedAggregation: []"
+            checkKeywordsExistsInExplain(max, expected_plan_fragment)
+        }
+      }
+    }
+  }
+
+  test("aggregate push down - nested column: Count(top level column) push down") {
+    val data = (1 to 10).map(i => Tuple1((i, Seq(s"val_$i"))))
+    withSQLConf(
+      SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true") {
+      withParquetTable(data, "t") {
+        val count = sql("SELECT Count(_1) FROM t")
+        count.queryExecution.optimizedPlan.collect {
+          case _: DataSourceV2ScanRelation =>
+            val expected_plan_fragment =
+              "PushedAggregation: [COUNT(_1)]"
+            checkKeywordsExistsInExplain(count, expected_plan_fragment)
+        }
+        checkAnswer(count, Seq(Row(10)))
+      }
+    }
+  }
+
+  test("aggregate push down - nested column: Max(nested column) not push down") {
+    val data = (1 to 10).map(i => Tuple1((i, Seq(s"val_$i"))))
+    withSQLConf(
+      SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true") {
+      withParquetTable(data, "t") {
+        val max = sql("SELECT Max(_1._2[0]) FROM t")
+        max.queryExecution.optimizedPlan.collect {
+          case _: DataSourceV2ScanRelation =>
+            val expected_plan_fragment =
+              "PushedAggregation: []"
+            checkKeywordsExistsInExplain(max, expected_plan_fragment)
+        }
+      }
+    }
+  }
+
+  test("aggregate push down - nested column: Count(nested column) not push down") {
+    val data = (1 to 10).map(i => Tuple1((i, Seq(s"val_$i"))))
+    withSQLConf(
+      SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true") {
+      withParquetTable(data, "t") {
+        val count = sql("SELECT Count(_1._2[0]) FROM t")
+        count.queryExecution.optimizedPlan.collect {
+          case _: DataSourceV2ScanRelation =>
+            val expected_plan_fragment =
+              "PushedAggregation: []"
+            checkKeywordsExistsInExplain(count, expected_plan_fragment)
+        }
+        checkAnswer(count, Seq(Row(10)))
+      }
+    }
+  }
+
+  test("aggregate push down - Max(partition Col): not push dow") {
+    withTempPath { dir =>
+      spark.range(10).selectExpr("id", "id % 3 as p")
+        .write.partitionBy("p").parquet(dir.getCanonicalPath)
+      withTempView("tmp") {
+        spark.read.parquet(dir.getCanonicalPath).createOrReplaceTempView("tmp");
+        withSQLConf(SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true") {
+          val max = sql("SELECT Max(p) FROM tmp")
+          max.queryExecution.optimizedPlan.collect {
+            case _: DataSourceV2ScanRelation =>
+              val expected_plan_fragment =
+                "PushedAggregation: []"
+              checkKeywordsExistsInExplain(max, expected_plan_fragment)
+          }
+          checkAnswer(max, Seq(Row(2)))
+        }
+      }
+    }
+  }
+
+  test("aggregate push down - Count(partition Col): push down") {
+    withTempPath { dir =>
+      spark.range(10).selectExpr("id", "id % 3 as p")
+        .write.partitionBy("p").parquet(dir.getCanonicalPath)
+      withTempView("tmp") {
+        spark.read.parquet(dir.getCanonicalPath).createOrReplaceTempView("tmp");
+        val enableVectorizedReader = Seq("false", "true")
+        for (testVectorizedReader <- enableVectorizedReader) {
+          withSQLConf(SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true",
+            vectorizedReaderEnabledKey -> testVectorizedReader) {
+            val count = sql("SELECT COUNT(p) FROM tmp")
+            count.queryExecution.optimizedPlan.collect {
+              case _: DataSourceV2ScanRelation =>
+                val expected_plan_fragment =
+                  "PushedAggregation: [COUNT(p)]"
+                checkKeywordsExistsInExplain(count, expected_plan_fragment)
+            }
+            checkAnswer(count, Seq(Row(10)))
+          }
+        }
+      }
+    }
+  }
+
+  test("aggregate push down - Filter alias over aggregate") {
+    val data = Seq((-2, "abc", 2), (3, "def", 4), (6, "ghi", 2), (0, null, 19),
+      (9, "mno", 7), (2, null, 6))
+    withParquetTable(data, "t") {
+      withSQLConf(
+        SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true") {
+        val selectAgg = sql("SELECT min(_1) + max(_1) as res FROM t having res > 1")
+        selectAgg.queryExecution.optimizedPlan.collect {
+          case _: DataSourceV2ScanRelation =>
+            val expected_plan_fragment =
+              "PushedAggregation: [MIN(_1), MAX(_1)]"
+            checkKeywordsExistsInExplain(selectAgg, expected_plan_fragment)
+        }
+        checkAnswer(selectAgg, Seq(Row(7)))
+      }
+    }
+  }
+
+  test("aggregate push down - alias over aggregate") {
+    val data = Seq((-2, "abc", 2), (3, "def", 4), (6, "ghi", 2), (0, null, 19),
+      (9, "mno", 7), (2, null, 6))
+    withParquetTable(data, "t") {
+      withSQLConf(
+        SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true") {
+        val selectAgg = sql("SELECT min(_1) + 1 as minPlus1, min(_1) + 2 as minPlus2 FROM t")
+        selectAgg.queryExecution.optimizedPlan.collect {
+          case _: DataSourceV2ScanRelation =>
+            val expected_plan_fragment =
+              "PushedAggregation: [MIN(_1)]"
+            checkKeywordsExistsInExplain(selectAgg, expected_plan_fragment)
+        }
+        checkAnswer(selectAgg, Seq(Row(-1, 0)))
+      }
+    }
+  }
+
+  test("aggregate push down - aggregate over alias not push down") {
+    val data = Seq((-2, "abc", 2), (3, "def", 4), (6, "ghi", 2), (0, null, 19),
+      (9, "mno", 7), (2, null, 6))
+    withParquetTable(data, "t") {
+      withSQLConf(
+        SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true") {
+        val df = spark.table("t")
+        val query = df.select($"_1".as("col1")).agg(min($"col1"))
+        query.queryExecution.optimizedPlan.collect {
+          case _: DataSourceV2ScanRelation =>
+            val expected_plan_fragment =
+              "PushedAggregation: []"  // aggregate alias not pushed down
+            checkKeywordsExistsInExplain(query, expected_plan_fragment)
+        }
+        checkAnswer(query, Seq(Row(-2)))
+      }
+    }
+  }
+
+  test("aggregate push down - query with group by not push down") {
+    val data = Seq((-2, "abc", 2), (3, "def", 4), (6, "ghi", 2), (0, null, 19),
+      (9, "mno", 7), (2, null, 7))
+    withParquetTable(data, "t") {
+      withSQLConf(
+        SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true") {
+        // aggregate not pushed down if there is group by
+        val selectAgg = sql("SELECT min(_1) FROM t GROUP BY _3 ")
+        selectAgg.queryExecution.optimizedPlan.collect {
+          case _: DataSourceV2ScanRelation =>
+            val expected_plan_fragment =
+              "PushedAggregation: []"
+            checkKeywordsExistsInExplain(selectAgg, expected_plan_fragment)
+        }
+        checkAnswer(selectAgg, Seq(Row(-2), Row(0), Row(2), Row(3)))
+      }
+    }
+  }
+
+  test("aggregate push down - query with filter not push down") {
+    val data = Seq((-2, "abc", 2), (3, "def", 4), (6, "ghi", 2), (0, null, 19),
+      (9, "mno", 7), (2, null, 7))
+    withParquetTable(data, "t") {
+      withSQLConf(
+        SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true") {
+        // aggregate not pushed down if there is filter
+        val selectAgg = sql("SELECT min(_3) FROM t WHERE _1 > 0")
+        selectAgg.queryExecution.optimizedPlan.collect {
+          case _: DataSourceV2ScanRelation =>
+            val expected_plan_fragment =
+              "PushedAggregation: []"
+            checkKeywordsExistsInExplain(selectAgg, expected_plan_fragment)
+        }
+        checkAnswer(selectAgg, Seq(Row(2)))
+      }
+    }
+  }
+
+  test("aggregate push down - push down only if all the aggregates can be pushed down") {
+    val data = Seq((-2, "abc", 2), (3, "def", 4), (6, "ghi", 2), (0, null, 19),
+      (9, "mno", 7), (2, null, 7))
+    withParquetTable(data, "t") {
+      withSQLConf(
+        SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true") {
+        // not push down since sum can't be pushed down
+        val selectAgg = sql("SELECT min(_1), sum(_3) FROM t")
+        selectAgg.queryExecution.optimizedPlan.collect {
+          case _: DataSourceV2ScanRelation =>
+            val expected_plan_fragment =
+              "PushedAggregation: []"
+            checkKeywordsExistsInExplain(selectAgg, expected_plan_fragment)
+        }
+        checkAnswer(selectAgg, Seq(Row(-2, 41)))
+      }
+    }
+  }
+
+  test("aggregate push down - MIN/MAX/COUNT") {
+    val data = Seq((-2, "abc", 2), (3, "def", 4), (6, "ghi", 2), (0, null, 19),
+      (9, "mno", 7), (2, null, 6))
+    withParquetTable(data, "t") {
+      withSQLConf(
+        SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true") {
+        val selectAgg = sql("SELECT min(_3), min(_3), max(_3), min(_1), max(_1), max(_1)," +
+          " count(*), count(_1), count(_2), count(_3) FROM t")
+        selectAgg.queryExecution.optimizedPlan.collect {
+          case _: DataSourceV2ScanRelation =>
+            val expected_plan_fragment =
+              "PushedAggregation: [MIN(_3), " +
+                "MAX(_3), " +
+                "MIN(_1), " +
+                "MAX(_1), " +
+                "COUNT(*), " +
+                "COUNT(_1), " +
+                "COUNT(_2), " +
+                "COUNT(_3)]"
+            checkKeywordsExistsInExplain(selectAgg, expected_plan_fragment)
+        }
+
+        checkAnswer(selectAgg, Seq(Row(2, 2, 19, -2, 9, 9, 6, 6, 4, 6)))
+      }
+    }
+  }
+
+  test("aggregate push down - different data types") {
+    implicit class StringToDate(s: String) {
+      def date: Date = Date.valueOf(s)
+    }
+
+    implicit class StringToTs(s: String) {
+      def ts: Timestamp = Timestamp.valueOf(s)
+    }
+
+    val rows =
+      Seq(
+        Row(
+          "a string",
+          true,
+          10.toByte,
+          "Spark SQL".getBytes,
+          12.toShort,
+          3,
+          Long.MaxValue,
+          0.15.toFloat,
+          0.75D,
+          Decimal("12.345678"),
+          ("2021-01-01").date,
+          ("2015-01-01 23:50:59.123").ts),
+        Row(
+          "test string",
+          false,
+          1.toByte,
+          "Parquet".getBytes,
+          2.toShort,
+          null,
+          Long.MinValue,
+          0.25.toFloat,
+          0.85D,
+          Decimal("1.2345678"),
+          ("2015-01-01").date,
+          ("2021-01-01 23:50:59.123").ts),
+        Row(
+          null,
+          true,
+          10000.toByte,
+          "Spark ML".getBytes,
+          222.toShort,
+          113,
+          11111111L,
+          0.25.toFloat,
+          0.75D,
+          Decimal("12345.678"),
+          ("2004-06-19").date,
+          ("1999-08-26 10:43:59.123").ts)
+      )
+
+    val schema = StructType(List(StructField("StringCol", StringType, true),
+      StructField("BooleanCol", BooleanType, false),
+      StructField("ByteCol", ByteType, false),
+      StructField("BinaryCol", BinaryType, false),
+      StructField("ShortCol", ShortType, false),
+      StructField("IntegerCol", IntegerType, true),
+      StructField("LongCol", LongType, false),
+      StructField("FloatCol", FloatType, false),
+      StructField("DoubleCol", DoubleType, false),
+      StructField("DecimalCol", DecimalType(25, 5), true),
+      StructField("DateCol", DateType, false),
+      StructField("TimestampCol", TimestampType, false)).toArray)
+
+    val rdd = sparkContext.parallelize(rows)
+    withTempPath { file =>
+      spark.createDataFrame(rdd, schema).write.parquet(file.getCanonicalPath)
+      withTempView("test") {
+        spark.read.parquet(file.getCanonicalPath).createOrReplaceTempView("test")
+        val enableVectorizedReader = Seq("false", "true")
+        for (testVectorizedReader <- enableVectorizedReader) {
+          withSQLConf(SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true",
+            vectorizedReaderEnabledKey -> testVectorizedReader) {
+
+            val testMinWithTS = sql("SELECT min(StringCol), min(BooleanCol), min(ByteCol), " +
+              "min(BinaryCol), min(ShortCol), min(IntegerCol), min(LongCol), min(FloatCol), " +
+              "min(DoubleCol), min(DecimalCol), min(DateCol), min(TimestampCol) FROM test")
+
+            // INT96 (Timestamp) sort order is undefined, parquet doesn't return stats for this type
+            // so aggregates are not pushed down
+            testMinWithTS.queryExecution.optimizedPlan.collect {
+              case _: DataSourceV2ScanRelation =>
+                val expected_plan_fragment =
+                  "PushedAggregation: []"
+                checkKeywordsExistsInExplain(testMinWithTS, expected_plan_fragment)
+            }
+
+            checkAnswer(testMinWithTS, Seq(Row("a string", false, 1.toByte, "Parquet".getBytes,
+              2.toShort, 3, -9223372036854775808L, 0.15.toFloat, 0.75D, 1.23457,
+              ("2004-06-19").date, ("1999-08-26 10:43:59.123").ts)))
+
+            val testMinWithOutTS = sql("SELECT min(StringCol), min(BooleanCol), min(ByteCol), " +
+              "min(BinaryCol), min(ShortCol), min(IntegerCol), min(LongCol), min(FloatCol), " +
+              "min(DoubleCol), min(DecimalCol), min(DateCol) FROM test")
+
+            testMinWithOutTS.queryExecution.optimizedPlan.collect {
+              case _: DataSourceV2ScanRelation =>
+                val expected_plan_fragment =
+                  "PushedAggregation: [MIN(StringCol), " +
+                    "MIN(BooleanCol), " +
+                    "MIN(ByteCol), " +
+                    "MIN(BinaryCol), " +
+                    "MIN(ShortCol), " +
+                    "MIN(IntegerCol), " +
+                    "MIN(LongCol), " +
+                    "MIN(FloatCol), " +
+                    "MIN(DoubleCol), " +
+                    "MIN(DecimalCol), " +
+                    "MIN(DateCol)]"
+                checkKeywordsExistsInExplain(testMinWithOutTS, expected_plan_fragment)
+            }
+
+            checkAnswer(testMinWithOutTS, Seq(Row("a string", false, 1.toByte, "Parquet".getBytes,
+              2.toShort, 3, -9223372036854775808L, 0.15.toFloat, 0.75D, 1.23457,
+              ("2004-06-19").date)))
+
+            val testMaxWithTS = sql("SELECT max(StringCol), max(BooleanCol), max(ByteCol), " +
+              "max(BinaryCol), max(ShortCol), max(IntegerCol), max(LongCol), max(FloatCol), " +
+              "max(DoubleCol), max(DecimalCol), max(DateCol), max(TimestampCol) FROM test")
+
+            // INT96 (Timestamp) sort order is undefined, parquet doesn't return stats for this type
+            // so aggregates are not pushed down
+            testMaxWithTS.queryExecution.optimizedPlan.collect {
+              case _: DataSourceV2ScanRelation =>
+                val expected_plan_fragment =
+                  "PushedAggregation: []"
+                checkKeywordsExistsInExplain(testMaxWithTS, expected_plan_fragment)
+            }
+
+            checkAnswer(testMaxWithTS, Seq(Row("test string", true, 16.toByte,
+              "Spark SQL".getBytes, 222.toShort, 113, 9223372036854775807L, 0.25.toFloat, 0.85D,
+              12345.678, ("2021-01-01").date, ("2021-01-01 23:50:59.123").ts)))
+
+            val testMaxWithoutTS = sql("SELECT max(StringCol), max(BooleanCol), max(ByteCol), " +
+              "max(BinaryCol), max(ShortCol), max(IntegerCol), max(LongCol), max(FloatCol), " +
+              "max(DoubleCol), max(DecimalCol), max(DateCol) FROM test")
+
+            testMaxWithoutTS.queryExecution.optimizedPlan.collect {
+              case _: DataSourceV2ScanRelation =>
+                val expected_plan_fragment =
+                  "PushedAggregation: [MAX(StringCol), " +
+                    "MAX(BooleanCol), " +
+                    "MAX(ByteCol), " +
+                    "MAX(BinaryCol), " +
+                    "MAX(ShortCol), " +
+                    "MAX(IntegerCol), " +
+                    "MAX(LongCol), " +
+                    "MAX(FloatCol), " +
+                    "MAX(DoubleCol), " +
+                    "MAX(DecimalCol), " +
+                    "MAX(DateCol)]"
+                checkKeywordsExistsInExplain(testMaxWithoutTS, expected_plan_fragment)
+            }
+
+            checkAnswer(testMaxWithoutTS, Seq(Row("test string", true, 16.toByte,
+              "Spark SQL".getBytes, 222.toShort, 113, 9223372036854775807L, 0.25.toFloat, 0.85D,
+              12345.678, ("2021-01-01").date)))
+
+            val testCount = sql("SELECT count(StringCol), count(BooleanCol)," +
+              " count(ByteCol), count(BinaryCol), count(ShortCol), count(IntegerCol)," +
+              " count(LongCol), count(FloatCol), count(DoubleCol)," +
+              " count(DecimalCol), count(DateCol), count(TimestampCol) FROM test")
+
+            testCount.queryExecution.optimizedPlan.collect {
+              case _: DataSourceV2ScanRelation =>
+                val expected_plan_fragment =
+                  "PushedAggregation: [" +
+                    "COUNT(StringCol), " +
+                    "COUNT(BooleanCol), " +
+                    "COUNT(ByteCol), " +
+                    "COUNT(BinaryCol), " +
+                    "COUNT(ShortCol), " +
+                    "COUNT(IntegerCol), " +
+                    "COUNT(LongCol), " +
+                    "COUNT(FloatCol), " +
+                    "COUNT(DoubleCol), " +
+                    "COUNT(DecimalCol), " +
+                    "COUNT(DateCol), " +
+                    "COUNT(TimestampCol)]"
+                checkKeywordsExistsInExplain(testCount, expected_plan_fragment)
+            }
+
+            checkAnswer(testCount, Seq(Row(2, 3, 3, 3, 3, 2, 3, 3, 3, 3, 3, 3)))
+          }
+        }
+      }
+    }
+  }
+
+  test("aggregate push down - column name case sensitivity") {
+    val enableVectorizedReader = Seq("false", "true")
+    for (testVectorizedReader <- enableVectorizedReader) {
+      withSQLConf(SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key -> "true",
+        vectorizedReaderEnabledKey -> testVectorizedReader) {
+        withTempPath { dir =>
+          spark.range(10).selectExpr("id", "id % 3 as p")
+            .write.partitionBy("p").parquet(dir.getCanonicalPath)
+          withTempView("tmp") {
+            spark.read.parquet(dir.getCanonicalPath).createOrReplaceTempView("tmp");
+            val selectAgg = sql("SELECT max(iD), min(Id) FROM tmp")
+            selectAgg.queryExecution.optimizedPlan.collect {
+              case _: DataSourceV2ScanRelation =>
+                val expected_plan_fragment =
+                  "PushedAggregation: [MAX(id), MIN(id)]"
+                checkKeywordsExistsInExplain(selectAgg, expected_plan_fragment)
+            }
+            checkAnswer(selectAgg, Seq(Row(9, 0)))
+          }
+        }
+      }
+    }
+  }
+}
+
+class ParquetV1AggregatePushDownSuite extends ParquetAggregatePushDownSuite {
+
+  override protected def sparkConf: SparkConf =
+    super
+      .sparkConf
+      .set(SQLConf.USE_V1_SOURCE_LIST, "parquet")
+}
+
+class ParquetV2AggregatePushDownSuite extends ParquetAggregatePushDownSuite {
+
+  override protected def sparkConf: SparkConf =
+    super
+      .sparkConf
+      .set(SQLConf.USE_V1_SOURCE_LIST, "")
+}

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org