Posted to reviews@spark.apache.org by GitBox <gi...@apache.org> on 2021/10/18 17:02:14 UTC

[GitHub] [spark] sunchao commented on a change in pull request #34298: [SPARK-34960][SQL] Aggregate push down for ORC

sunchao commented on a change in pull request #34298:
URL: https://github.com/apache/spark/pull/34298#discussion_r731103956



##########
File path: sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnsStatistics.java
##########
@@ -0,0 +1,56 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.orc;
+
+import org.apache.orc.ColumnStatistics;
+
+import java.util.ArrayList;
+import java.util.List;
+
+/**
+ * Columns statistics interface wrapping ORC {@link ColumnStatistics}s.
+ *
+ * Because ORC {@link ColumnStatistics}s are stored as an flatten array in ORC file footer,

Review comment:
       Is it stored in pre-order, and does it flatten all the nested types? Might be worth mentioning here.
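
       For reference, a minimal sketch of what a pre-order flattening of a nested schema looks like. `Node` and `flatten` are hypothetical stand-ins for illustration, not ORC's `TypeDescription` or any ORC API; the sketch assumes, as I understand the ORC format, that the root struct takes the first slot and nested children follow their parent.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

public class PreOrderFlattenSketch {
  // Hypothetical schema node used only for this illustration.
  static final class Node {
    final String name;
    final List<Node> children;

    Node(String name, Node... children) {
      this.name = name;
      this.children = Arrays.asList(children);
    }
  }

  // Pre-order: record the node itself, then each child from left to right.
  static List<String> flatten(Node root) {
    List<String> out = new ArrayList<>();
    visit(root, out);
    return out;
  }

  private static void visit(Node node, List<String> out) {
    out.add(node.name);
    for (Node child : node.children) {
      visit(child, out);
    }
  }

  public static void main(String[] args) {
    // Models struct<a:int, b:struct<c:string, d:double>, e:bigint>.
    Node root = new Node("root",
        new Node("a"),
        new Node("b", new Node("c"), new Node("d")),
        new Node("e"));
    // Prints [root, a, b, c, d, e]
    System.out.println(flatten(root));
  }
}
```

       Under that assumption, the index of a top-level field in the flattened array is not simply its position in the schema once an earlier field is nested, which is the reason a tree-shaped wrapper is useful.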

##########
File path: sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcFooterReader.java
##########
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.orc;
+
+import org.apache.orc.ColumnStatistics;
+import org.apache.orc.Reader;
+import org.apache.orc.TypeDescription;
+import org.apache.spark.sql.types.*;
+
+import java.util.Arrays;
+import java.util.LinkedList;
+import java.util.Queue;
+
+/**
+ * `OrcFooterReader` is a util class which encapsulates the helper
+ * methods of reading ORC file footer.
+ */
+public class OrcFooterReader {
+
+  /**
+   * Read the columns statistics from ORC file footer.
+   *
+   * @param orcReader the reader to read ORC file footer.
+   * @return Statistics for all columns in the file.
+   */
+  public static OrcColumnsStatistics readStatistics(Reader orcReader) {
+    TypeDescription orcSchema = orcReader.getSchema();
+    ColumnStatistics[] orcStatistics = orcReader.getStatistics();
+    StructType dataType = OrcUtils.toCatalystSchema(orcSchema);

Review comment:
       nit: maybe rename `dataType` to `sparkSchema`.

##########
File path: sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala
##########
@@ -377,4 +381,124 @@ object OrcUtils extends Logging {
       case _ => false
     }
   }
+
+  /**
+   * When the partial aggregates (Max/Min/Count) are pushed down to ORC, we don't need to read data
+   * from ORC and aggregate at Spark layer. Instead we want to get the partial aggregates
+   * (Max/Min/Count) result using the statistics information from ORC file footer, and then
+   * construct an InternalRow from these aggregate results.
+   *
+   * @return Aggregate results in the format of InternalRow
+   */
+  def createAggInternalRowFromFooter(
+      reader: Reader,
+      filePath: String,
+      dataSchema: StructType,
+      partitionSchema: StructType,
+      aggregation: Aggregation,
+      aggSchema: StructType,
+      isCaseSensitive: Boolean): InternalRow = {
+    require(aggregation.groupByColumns.length == 0,
+      s"aggregate $aggregation with group-by column shouldn't be pushed down")
+    val columnsStatistics = OrcFooterReader.readStatistics(reader)
+
+    // Get column statistics with column name.
+    def getColumnStatistics(columnName: String): ColumnStatistics = {
+      val columnIndex = dataSchema.fieldNames.indexOf(columnName)
+      columnsStatistics.get(columnIndex).getStatistics
+    }
+
+    // Get Min/Max statistics and store as ORC `WritableComparable` format.
+    def getMinMaxFromColumnStatistics(
+        statistics: ColumnStatistics,
+        dataType: DataType,
+        isMax: Boolean): WritableComparable[_] = {
+      statistics match {
+        case s: BooleanColumnStatistics =>
+          val value = if (isMax) s.getTrueCount > 0 else !(s.getFalseCount > 0)
+          new BooleanWritable(value)
+        case s: IntegerColumnStatistics =>
+          val value = if (isMax) s.getMaximum else s.getMinimum
+          dataType match {
+            case ByteType => new ByteWritable(value.toByte)
+            case ShortType => new ShortWritable(value.toShort)
+            case IntegerType => new IntWritable(value.toInt)
+            case LongType => new LongWritable(value)
+            case _ => throw new IllegalArgumentException(
+              s"getMaxFromColumnStatistics should not take type $dataType" +
+              "for IntegerColumnStatistics")
+          }
+        case s: DoubleColumnStatistics =>
+          val value = if (isMax) s.getMaximum else s.getMinimum
+          dataType match {
+            case FloatType => new FloatWritable(value.toFloat)
+            case DoubleType => new DoubleWritable(value)
+            case _ => throw new IllegalArgumentException(
+              s"getMaxFromColumnStatistics should not take type $dataType" +
+                "for DoubleColumnStatistics")
+          }
+        case s: DecimalColumnStatistics =>
+          new HiveDecimalWritable(if (isMax) s.getMaximum else s.getMinimum)
+        case s: StringColumnStatistics =>
+          new Text(if (isMax) s.getMaximum else s.getMinimum)
+        case s: DateColumnStatistics =>
+          new DateWritable(
+            if (isMax) s.getMaximumDayOfEpoch.toInt else s.getMinimumDayOfEpoch.toInt)
+        case _ => throw new IllegalArgumentException(
+          s"getMaxFromColumnStatistics should not take ${statistics.getClass.getName}: " +
+            s"$statistics as the ORC column statistics")
+      }
+    }
+
+    val aggORCValues: Seq[WritableComparable[_]] =
+      aggregation.aggregateExpressions.zipWithIndex.map {
+        case (max: Max, index) =>
+          val columnName = max.column.fieldNames.head
+          val statistics = getColumnStatistics(columnName)
+          val dataType = aggSchema(index).dataType
+          val value = getMinMaxFromColumnStatistics(statistics, dataType, isMax = true)
+          value
+        case (min: Min, index) =>
+          val columnName = min.column.fieldNames.head
+          val statistics = getColumnStatistics(columnName)
+          val dataType = aggSchema.apply(index).dataType
+          val value = getMinMaxFromColumnStatistics(statistics, dataType, isMax = false)
+          value
+        case (count: Count, _) =>
+          val columnName = count.column.fieldNames.head
+          val isPartitionColumn = partitionSchema.fields
+            .map(PartitioningUtils.getColName(_, isCaseSensitive))
+            .contains(columnName)
+          // NOTE: Count(columnName) doesn't include null values.
+          // org.apache.orc.ColumnStatistics.getNumberOfValues() returns number of non-null values

Review comment:
       I got confused reading this: does `getNumberOfValues` behave differently depending on whether the column it represents is top-level or not?

##########
File path: sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala
##########
@@ -377,4 +381,124 @@ object OrcUtils extends Logging {
       case _ => false
     }
   }
+
+  /**
+   * When the partial aggregates (Max/Min/Count) are pushed down to ORC, we don't need to read data
+   * from ORC and aggregate at Spark layer. Instead we want to get the partial aggregates
+   * (Max/Min/Count) result using the statistics information from ORC file footer, and then
+   * construct an InternalRow from these aggregate results.
+   *
+   * @return Aggregate results in the format of InternalRow
+   */
+  def createAggInternalRowFromFooter(
+      reader: Reader,
+      filePath: String,
+      dataSchema: StructType,
+      partitionSchema: StructType,
+      aggregation: Aggregation,
+      aggSchema: StructType,
+      isCaseSensitive: Boolean): InternalRow = {
+    require(aggregation.groupByColumns.length == 0,
+      s"aggregate $aggregation with group-by column shouldn't be pushed down")
+    val columnsStatistics = OrcFooterReader.readStatistics(reader)
+
+    // Get column statistics with column name.
+    def getColumnStatistics(columnName: String): ColumnStatistics = {
+      val columnIndex = dataSchema.fieldNames.indexOf(columnName)
+      columnsStatistics.get(columnIndex).getStatistics
+    }
+
+    // Get Min/Max statistics and store as ORC `WritableComparable` format.
+    def getMinMaxFromColumnStatistics(
+        statistics: ColumnStatistics,
+        dataType: DataType,
+        isMax: Boolean): WritableComparable[_] = {
+      statistics match {
+        case s: BooleanColumnStatistics =>
+          val value = if (isMax) s.getTrueCount > 0 else !(s.getFalseCount > 0)
+          new BooleanWritable(value)
+        case s: IntegerColumnStatistics =>
+          val value = if (isMax) s.getMaximum else s.getMinimum
+          dataType match {
+            case ByteType => new ByteWritable(value.toByte)
+            case ShortType => new ShortWritable(value.toShort)
+            case IntegerType => new IntWritable(value.toInt)
+            case LongType => new LongWritable(value)
+            case _ => throw new IllegalArgumentException(
+              s"getMaxFromColumnStatistics should not take type $dataType" +
+              "for IntegerColumnStatistics")
+          }
+        case s: DoubleColumnStatistics =>
+          val value = if (isMax) s.getMaximum else s.getMinimum
+          dataType match {
+            case FloatType => new FloatWritable(value.toFloat)
+            case DoubleType => new DoubleWritable(value)
+            case _ => throw new IllegalArgumentException(
+              s"getMaxFromColumnStatistics should not take type $dataType" +
+                "for DoubleColumnStatistics")
+          }
+        case s: DecimalColumnStatistics =>
+          new HiveDecimalWritable(if (isMax) s.getMaximum else s.getMinimum)
+        case s: StringColumnStatistics =>
+          new Text(if (isMax) s.getMaximum else s.getMinimum)
+        case s: DateColumnStatistics =>
+          new DateWritable(
+            if (isMax) s.getMaximumDayOfEpoch.toInt else s.getMinimumDayOfEpoch.toInt)
+        case _ => throw new IllegalArgumentException(
+          s"getMaxFromColumnStatistics should not take ${statistics.getClass.getName}: " +
+            s"$statistics as the ORC column statistics")
+      }
+    }
+
+    val aggORCValues: Seq[WritableComparable[_]] =
+      aggregation.aggregateExpressions.zipWithIndex.map {
+        case (max: Max, index) =>
+          val columnName = max.column.fieldNames.head
+          val statistics = getColumnStatistics(columnName)
+          val dataType = aggSchema(index).dataType
+          val value = getMinMaxFromColumnStatistics(statistics, dataType, isMax = true)
+          value
+        case (min: Min, index) =>
+          val columnName = min.column.fieldNames.head
+          val statistics = getColumnStatistics(columnName)
+          val dataType = aggSchema.apply(index).dataType
+          val value = getMinMaxFromColumnStatistics(statistics, dataType, isMax = false)
+          value
+        case (count: Count, _) =>
+          val columnName = count.column.fieldNames.head
+          val isPartitionColumn = partitionSchema.fields
+            .map(PartitioningUtils.getColName(_, isCaseSensitive))
+            .contains(columnName)
+          // NOTE: Count(columnName) doesn't include null values.
+          // org.apache.orc.ColumnStatistics.getNumberOfValues() returns number of non-null values
+          // for ColumnStatistics of individual column. In addition to this, ORC also returns
+          // number of non-null and null values for its top-level
+          // ColumnStatistics.getNumberOfValues().
+          val nonNullRowsCount = if (isPartitionColumn) {
+            val topLevelStatistics = columnsStatistics.getStatistics
+            if (topLevelStatistics.hasNull) {
+              throw new SparkException(s"Illegal ORC top-level column statistics with NULL " +
+                s"values: $topLevelStatistics. Aggregate expression: $count")
+            }
+            topLevelStatistics.getNumberOfValues
+          } else {
+            getColumnStatistics(columnName).getNumberOfValues
+          }
+          new LongWritable(nonNullRowsCount)
+        case (_: CountStar, _) =>
+          // Count(*) includes both null and non-null values.
+          val topLevelStatistics = columnsStatistics.getStatistics
+          if (topLevelStatistics.hasNull) {
+            throw new SparkException(s"Illegal ORC top-level column statistics with NULL " +

Review comment:
       Not sure why we should throw an exception here. Doesn't `count(*)` include NULLs?
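
       To make the semantics in question concrete, here is a small self-contained sketch of the SQL behavior the comments in the diff describe: `count(*)` counts every row, while `count(col)` skips NULLs. It works on an in-memory list and says nothing about what ORC statistics actually return; `countStar` and `countColumn` are hypothetical names.

```java
import java.util.Arrays;
import java.util.List;
import java.util.Objects;

public class CountSemanticsSketch {
  // count(*): every row counts, whether or not the value is NULL.
  static long countStar(List<Integer> column) {
    return column.size();
  }

  // count(col): only rows whose value is non-NULL count.
  static long countColumn(List<Integer> column) {
    return column.stream().filter(Objects::nonNull).count();
  }

  public static void main(String[] args) {
    List<Integer> column = Arrays.asList(1, null, 3, null);
    System.out.println(countStar(column));   // prints 4
    System.out.println(countColumn(column)); // prints 2
  }
}
```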

##########
File path: sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala
##########
@@ -37,35 +38,65 @@ case class OrcScan(
     readDataSchema: StructType,
     readPartitionSchema: StructType,
     options: CaseInsensitiveStringMap,
+    pushedAggregate: Option[Aggregation] = None,
     pushedFilters: Array[Filter],
     partitionFilters: Seq[Expression] = Seq.empty,
     dataFilters: Seq[Expression] = Seq.empty) extends FileScan {
-  override def isSplitable(path: Path): Boolean = true
+  override def isSplitable(path: Path): Boolean = {
+    // If aggregate is pushed down, only the file footer will be read once,
+    // so file should be not split across multiple tasks.
+    pushedAggregate.isEmpty

Review comment:
       This seems like a better approach than what we are doing on the Parquet side, cc @huaxingao. Also, maybe we should change how we measure file weight when combining tasks for aggregate pushdown: computing stats is much cheaper, so we can combine multiple large files into a single task.

##########
File path: sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/AggregatePushDownUtils.scala
##########
@@ -0,0 +1,136 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources
+
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.connector.expressions.NamedReference
+import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Aggregation, Count, CountStar, Max, Min}
+import org.apache.spark.sql.execution.RowToColumnConverter
+import org.apache.spark.sql.execution.vectorized.{OffHeapColumnVector, OnHeapColumnVector}
+import org.apache.spark.sql.types.{DataType, LongType, StructField, StructType}
+import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector}
+
+/**
+ * Utility class for aggregate push down to Parquet and ORC.
+ */
+object AggregatePushDownUtils {
+
+  /**
+   * Get the data schema for aggregate to be pushed down.
+   */
+  def getSchemaForPushedAggregation(
+      aggregation: Aggregation,
+      schema: StructType,
+      partitionNameSet: Set[String],

Review comment:
       nit: maybe `partitionNames`?

##########
File path: sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala
##########
@@ -87,84 +86,45 @@ case class ParquetScanBuilder(
   override def pushedFilters(): Array[Filter] = pushedParquetFilters
 
   override def pushAggregation(aggregation: Aggregation): Boolean = {
-
-    def getStructFieldForCol(col: NamedReference): StructField = {
-      schema.nameToField(col.fieldNames.head)
-    }
-
-    def isPartitionCol(col: NamedReference) = {
-      partitionNameSet.contains(col.fieldNames.head)
+    if (!sparkSession.sessionState.conf.parquetAggregatePushDown) {
+      return false
     }
 
-    def processMinOrMax(agg: AggregateFunc): Boolean = {
-      val (column, aggType) = agg match {
-        case max: Max => (max.column, "max")
-        case min: Min => (min.column, "min")
-        case _ =>
-          throw new IllegalArgumentException(s"Unexpected type of AggregateFunc ${agg.describe}")
-      }
-
-      if (isPartitionCol(column)) {
-        // don't push down partition column, footer doesn't have max/min for partition column
-        return false
-      }
-      val structField = getStructFieldForCol(column)
-
-      structField.dataType match {
-        // not push down complex type
-        // not push down Timestamp because INT96 sort order is undefined,
-        // Parquet doesn't return statistics for INT96
-        case StructType(_) | ArrayType(_, _) | MapType(_, _, _) | TimestampType =>
+    def isAllowedTypeForMinMaxAggregate(dataType: DataType): Boolean = {
+      dataType match {
+        // Not push down complex type.
+        // Not push down Timestamp because INT96 sort order is undefined,
+        // Parquet doesn't return statistics for INT96.
+        // Not push down Binary type as Parquet can truncate the statistics.
+        case StructType(_) | ArrayType(_, _) | MapType(_, _, _) | TimestampType | BinaryType =>

Review comment:
       We should put `StringType` here too.
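
       A minimal sketch of why truncated statistics are unsafe for min/max push down, assuming a hypothetical writer that keeps only a fixed-length prefix of the maximum. `truncatedMax` and the limit of 3 are made up for illustration and do not reflect how Parquet or ORC actually truncate; a real writer may also adjust the truncated value so it stays a valid bound, but it is still not guaranteed to be a value present in the data.

```java
import java.util.Arrays;
import java.util.Collections;
import java.util.List;

public class TruncatedStatsSketch {
  // The answer MAX(col) must return: the largest value actually stored.
  static String trueMax(List<String> values) {
    return Collections.max(values);
  }

  // Hypothetical statistic: the max cut down to at most `limit` characters.
  static String truncatedMax(List<String> values, int limit) {
    String max = Collections.max(values);
    return max.length() <= limit ? max : max.substring(0, limit);
  }

  public static void main(String[] args) {
    List<String> values = Arrays.asList("apple", "banana", "bananas");
    System.out.println(trueMax(values));         // prints bananas
    System.out.println(truncatedMax(values, 3)); // prints ban
  }
}
```

       Returning the truncated statistic as the aggregate result would give `ban`, which is neither the true maximum nor a value in the column.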

##########
File path: sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScanBuilder.scala
##########
@@ -58,4 +72,35 @@ case class OrcScanBuilder(
       Array.empty[Filter]
     }
   }
+
+  override def pushAggregation(aggregation: Aggregation): Boolean = {
+    if (!sparkSession.sessionState.conf.orcAggregatePushDown) {
+      return false
+    }
+
+    def isAllowedTypeForMinMaxAggregate(dataType: DataType): Boolean = {
+      dataType match {
+        // Not push down complex and Timestamp type.
+        // Not push down Binary type as ORC does not write min/max statistics for it.
+        case StructType(_) | ArrayType(_, _) | MapType(_, _, _) | TimestampType | BinaryType =>

Review comment:
       Hm, should we add `StringType` here? How does ORC store stats for long strings?

##########
File path: sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala
##########
@@ -377,4 +381,124 @@ object OrcUtils extends Logging {
       case _ => false
     }
   }
+
+  /**
+   * When the partial aggregates (Max/Min/Count) are pushed down to ORC, we don't need to read data
+   * from ORC and aggregate at Spark layer. Instead we want to get the partial aggregates
+   * (Max/Min/Count) result using the statistics information from ORC file footer, and then
+   * construct an InternalRow from these aggregate results.
+   *
+   * @return Aggregate results in the format of InternalRow
+   */
+  def createAggInternalRowFromFooter(
+      reader: Reader,
+      filePath: String,
+      dataSchema: StructType,
+      partitionSchema: StructType,
+      aggregation: Aggregation,
+      aggSchema: StructType,
+      isCaseSensitive: Boolean): InternalRow = {
+    require(aggregation.groupByColumns.length == 0,
+      s"aggregate $aggregation with group-by column shouldn't be pushed down")
+    val columnsStatistics = OrcFooterReader.readStatistics(reader)
+
+    // Get column statistics with column name.
+    def getColumnStatistics(columnName: String): ColumnStatistics = {
+      val columnIndex = dataSchema.fieldNames.indexOf(columnName)
+      columnsStatistics.get(columnIndex).getStatistics
+    }
+
+    // Get Min/Max statistics and store as ORC `WritableComparable` format.
+    def getMinMaxFromColumnStatistics(
+        statistics: ColumnStatistics,
+        dataType: DataType,
+        isMax: Boolean): WritableComparable[_] = {
+      statistics match {
+        case s: BooleanColumnStatistics =>
+          val value = if (isMax) s.getTrueCount > 0 else !(s.getFalseCount > 0)
+          new BooleanWritable(value)
+        case s: IntegerColumnStatistics =>
+          val value = if (isMax) s.getMaximum else s.getMinimum
+          dataType match {
+            case ByteType => new ByteWritable(value.toByte)
+            case ShortType => new ShortWritable(value.toShort)
+            case IntegerType => new IntWritable(value.toInt)
+            case LongType => new LongWritable(value)
+            case _ => throw new IllegalArgumentException(
+              s"getMaxFromColumnStatistics should not take type $dataType" +
+              "for IntegerColumnStatistics")
+          }
+        case s: DoubleColumnStatistics =>
+          val value = if (isMax) s.getMaximum else s.getMinimum
+          dataType match {
+            case FloatType => new FloatWritable(value.toFloat)
+            case DoubleType => new DoubleWritable(value)
+            case _ => throw new IllegalArgumentException(
+              s"getMaxFromColumnStatistics should not take type $dataType" +
+                "for DoubleColumnStatistics")
+          }
+        case s: DecimalColumnStatistics =>
+          new HiveDecimalWritable(if (isMax) s.getMaximum else s.getMinimum)
+        case s: StringColumnStatistics =>
+          new Text(if (isMax) s.getMaximum else s.getMinimum)
+        case s: DateColumnStatistics =>
+          new DateWritable(
+            if (isMax) s.getMaximumDayOfEpoch.toInt else s.getMinimumDayOfEpoch.toInt)
+        case _ => throw new IllegalArgumentException(
+          s"getMaxFromColumnStatistics should not take ${statistics.getClass.getName}: " +
+            s"$statistics as the ORC column statistics")
+      }
+    }
+
+    val aggORCValues: Seq[WritableComparable[_]] =
+      aggregation.aggregateExpressions.zipWithIndex.map {
+        case (max: Max, index) =>
+          val columnName = max.column.fieldNames.head
+          val statistics = getColumnStatistics(columnName)
+          val dataType = aggSchema(index).dataType
+          val value = getMinMaxFromColumnStatistics(statistics, dataType, isMax = true)
+          value
+        case (min: Min, index) =>
+          val columnName = min.column.fieldNames.head
+          val statistics = getColumnStatistics(columnName)
+          val dataType = aggSchema.apply(index).dataType
+          val value = getMinMaxFromColumnStatistics(statistics, dataType, isMax = false)
+          value
+        case (count: Count, _) =>
+          val columnName = count.column.fieldNames.head
+          val isPartitionColumn = partitionSchema.fields
+            .map(PartitioningUtils.getColName(_, isCaseSensitive))
+            .contains(columnName)
+          // NOTE: Count(columnName) doesn't include null values.
+          // org.apache.orc.ColumnStatistics.getNumberOfValues() returns number of non-null values
+          // for ColumnStatistics of individual column. In addition to this, ORC also returns
+          // number of non-null and null values for its top-level
+          // ColumnStatistics.getNumberOfValues().
+          val nonNullRowsCount = if (isPartitionColumn) {
+            val topLevelStatistics = columnsStatistics.getStatistics
+            if (topLevelStatistics.hasNull) {
+              throw new SparkException(s"Illegal ORC top-level column statistics with NULL " +

Review comment:
       Hm, does this mean we have an invalid ORC file, or is it a valid file but Spark can't handle the case?

##########
File path: sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnsStatistics.java
##########
@@ -0,0 +1,56 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.orc;
+
+import org.apache.orc.ColumnStatistics;
+
+import java.util.ArrayList;
+import java.util.List;
+
+/**
+ * Columns statistics interface wrapping ORC {@link ColumnStatistics}s.
+ *
+ * Because ORC {@link ColumnStatistics}s are stored as an flatten array in ORC file footer,
+ * this class is used to covert ORC {@link ColumnStatistics}s from array to nested tree structure,
+ * according to data types. This is used for aggregate push down in ORC.
+ */
+public class OrcColumnsStatistics {

Review comment:
       `OrcColumnsStatistics` -> `OrcColumnStatistics`?

##########
File path: sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala
##########
@@ -377,4 +381,124 @@ object OrcUtils extends Logging {
       case _ => false
     }
   }
+
+  /**
+   * When the partial aggregates (Max/Min/Count) are pushed down to ORC, we don't need to read data
+   * from ORC and aggregate at Spark layer. Instead we want to get the partial aggregates
+   * (Max/Min/Count) result using the statistics information from ORC file footer, and then
+   * construct an InternalRow from these aggregate results.
+   *
+   * @return Aggregate results in the format of InternalRow
+   */
+  def createAggInternalRowFromFooter(
+      reader: Reader,
+      filePath: String,
+      dataSchema: StructType,
+      partitionSchema: StructType,
+      aggregation: Aggregation,
+      aggSchema: StructType,
+      isCaseSensitive: Boolean): InternalRow = {
+    require(aggregation.groupByColumns.length == 0,
+      s"aggregate $aggregation with group-by column shouldn't be pushed down")
+    val columnsStatistics = OrcFooterReader.readStatistics(reader)
+
+    // Get column statistics with column name.
+    def getColumnStatistics(columnName: String): ColumnStatistics = {
+      val columnIndex = dataSchema.fieldNames.indexOf(columnName)
+      columnsStatistics.get(columnIndex).getStatistics
+    }
+
+    // Get Min/Max statistics and store as ORC `WritableComparable` format.
+    def getMinMaxFromColumnStatistics(
+        statistics: ColumnStatistics,
+        dataType: DataType,
+        isMax: Boolean): WritableComparable[_] = {
+      statistics match {
+        case s: BooleanColumnStatistics =>
+          val value = if (isMax) s.getTrueCount > 0 else !(s.getFalseCount > 0)
+          new BooleanWritable(value)
+        case s: IntegerColumnStatistics =>
+          val value = if (isMax) s.getMaximum else s.getMinimum
+          dataType match {
+            case ByteType => new ByteWritable(value.toByte)
+            case ShortType => new ShortWritable(value.toShort)
+            case IntegerType => new IntWritable(value.toInt)
+            case LongType => new LongWritable(value)
+            case _ => throw new IllegalArgumentException(
+              s"getMaxFromColumnStatistics should not take type $dataType" +
+              "for IntegerColumnStatistics")
+          }
+        case s: DoubleColumnStatistics =>
+          val value = if (isMax) s.getMaximum else s.getMinimum
+          dataType match {
+            case FloatType => new FloatWritable(value.toFloat)
+            case DoubleType => new DoubleWritable(value)
+            case _ => throw new IllegalArgumentException(
+              s"getMaxFromColumnStatistics should not take type $dataType" +
+                "for DoubleColumnStatistics")
+          }
+        case s: DecimalColumnStatistics =>
+          new HiveDecimalWritable(if (isMax) s.getMaximum else s.getMinimum)
+        case s: StringColumnStatistics =>
+          new Text(if (isMax) s.getMaximum else s.getMinimum)
+        case s: DateColumnStatistics =>
+          new DateWritable(
+            if (isMax) s.getMaximumDayOfEpoch.toInt else s.getMinimumDayOfEpoch.toInt)
+        case _ => throw new IllegalArgumentException(
+          s"getMaxFromColumnStatistics should not take ${statistics.getClass.getName}: " +
+            s"$statistics as the ORC column statistics")
+      }
+    }
+
+    val aggORCValues: Seq[WritableComparable[_]] =
+      aggregation.aggregateExpressions.zipWithIndex.map {
+        case (max: Max, index) =>
+          val columnName = max.column.fieldNames.head
+          val statistics = getColumnStatistics(columnName)
+          val dataType = aggSchema(index).dataType
+          val value = getMinMaxFromColumnStatistics(statistics, dataType, isMax = true)
+          value

Review comment:
       nit: `value` seems redundant

##########
File path: sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala
##########
@@ -377,4 +381,124 @@ object OrcUtils extends Logging {
       case _ => false
     }
   }
+
+  /**
+   * When the partial aggregates (Max/Min/Count) are pushed down to ORC, we don't need to read data
+   * from ORC and aggregate at Spark layer. Instead we want to get the partial aggregates
+   * (Max/Min/Count) result using the statistics information from ORC file footer, and then
+   * construct an InternalRow from these aggregate results.
+   *
+   * @return Aggregate results in the format of InternalRow
+   */
+  def createAggInternalRowFromFooter(
+      reader: Reader,
+      filePath: String,
+      dataSchema: StructType,
+      partitionSchema: StructType,
+      aggregation: Aggregation,
+      aggSchema: StructType,
+      isCaseSensitive: Boolean): InternalRow = {
+    require(aggregation.groupByColumns.length == 0,
+      s"aggregate $aggregation with group-by column shouldn't be pushed down")
+    val columnsStatistics = OrcFooterReader.readStatistics(reader)
+
+    // Get column statistics with column name.
+    def getColumnStatistics(columnName: String): ColumnStatistics = {
+      val columnIndex = dataSchema.fieldNames.indexOf(columnName)
+      columnsStatistics.get(columnIndex).getStatistics
+    }
+
+    // Get Min/Max statistics and store as ORC `WritableComparable` format.
+    def getMinMaxFromColumnStatistics(
+        statistics: ColumnStatistics,
+        dataType: DataType,
+        isMax: Boolean): WritableComparable[_] = {
+      statistics match {
+        case s: BooleanColumnStatistics =>
+          val value = if (isMax) s.getTrueCount > 0 else !(s.getFalseCount > 0)
+          new BooleanWritable(value)
+        case s: IntegerColumnStatistics =>
+          val value = if (isMax) s.getMaximum else s.getMinimum
+          dataType match {
+            case ByteType => new ByteWritable(value.toByte)
+            case ShortType => new ShortWritable(value.toShort)
+            case IntegerType => new IntWritable(value.toInt)
+            case LongType => new LongWritable(value)
+            case _ => throw new IllegalArgumentException(
+              s"getMaxFromColumnStatistics should not take type $dataType" +

Review comment:
       nit: missing a space at the end of this string, before the next part of the message
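
       A small sketch of the effect, assuming the continuation line is the
       `"for IntegerColumnStatistics"` literal and using a plain string as a hypothetical
       stand-in for the interpolated data type:

```scala
object MessageSpacingSketch {
  // Hypothetical stand-in for the Spark DataType that gets interpolated.
  val dataType = "StringType"

  // As in the diff: nothing separates the interpolated type from the next literal.
  val before: String =
    s"getMaxFromColumnStatistics should not take type $dataType" +
      "for IntegerColumnStatistics"

  // With a trailing space in the first literal the two parts read as separate words.
  val after: String =
    s"getMaxFromColumnStatistics should not take type $dataType " +
      "for IntegerColumnStatistics"
}
```

       The first message ends up containing `StringTypefor`, the second `StringType for`.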

##########
File path: sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/AggregatePushDownUtils.scala
##########
@@ -0,0 +1,136 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources
+
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.connector.expressions.NamedReference
+import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Aggregation, Count, CountStar, Max, Min}
+import org.apache.spark.sql.execution.RowToColumnConverter
+import org.apache.spark.sql.execution.vectorized.{OffHeapColumnVector, OnHeapColumnVector}
+import org.apache.spark.sql.types.{DataType, LongType, StructField, StructType}
+import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector}
+
+/**
+ * Utility class for aggregate push down to Parquet and ORC.
+ */
+object AggregatePushDownUtils {
+
+  /**
+   * Get the data schema for aggregate to be pushed down.
+   */
+  def getSchemaForPushedAggregation(
+      aggregation: Aggregation,
+      schema: StructType,
+      partitionNameSet: Set[String],
+      dataFilters: Seq[Expression],
+      isAllowedTypeForMinMaxAggregate: DataType => Boolean,
+      sparkSession: SparkSession): Option[StructType] = {

Review comment:
       nit: `sparkSession` is unused.

##########
File path: sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala
##########
@@ -377,4 +381,124 @@ object OrcUtils extends Logging {
       case _ => false
     }
   }
+
+  /**
+   * When the partial aggregates (Max/Min/Count) are pushed down to ORC, we don't need to read data
+   * from ORC and aggregate at Spark layer. Instead we want to get the partial aggregates
+   * (Max/Min/Count) result using the statistics information from ORC file footer, and then
+   * construct an InternalRow from these aggregate results.
+   *
+   * @return Aggregate results in the format of InternalRow
+   */
+  def createAggInternalRowFromFooter(
+      reader: Reader,
+      filePath: String,

Review comment:
       nit: `filePath` is unused

##########
File path: sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala
##########
@@ -377,4 +381,124 @@ object OrcUtils extends Logging {
       case _ => false
     }
   }
+
+  /**
+   * When the partial aggregates (Max/Min/Count) are pushed down to ORC, we don't need to read data
+   * from ORC and aggregate at Spark layer. Instead we want to get the partial aggregates
+   * (Max/Min/Count) result using the statistics information from ORC file footer, and then
+   * construct an InternalRow from these aggregate results.
+   *
+   * @return Aggregate results in the format of InternalRow
+   */
+  def createAggInternalRowFromFooter(
+      reader: Reader,
+      filePath: String,
+      dataSchema: StructType,
+      partitionSchema: StructType,
+      aggregation: Aggregation,
+      aggSchema: StructType,
+      isCaseSensitive: Boolean): InternalRow = {
+    require(aggregation.groupByColumns.length == 0,
+      s"aggregate $aggregation with group-by column shouldn't be pushed down")
+    val columnsStatistics = OrcFooterReader.readStatistics(reader)
+
+    // Get column statistics with column name.
+    def getColumnStatistics(columnName: String): ColumnStatistics = {
+      val columnIndex = dataSchema.fieldNames.indexOf(columnName)
+      columnsStatistics.get(columnIndex).getStatistics
+    }
+
+    // Get Min/Max statistics and store as ORC `WritableComparable` format.
+    def getMinMaxFromColumnStatistics(
+        statistics: ColumnStatistics,
+        dataType: DataType,
+        isMax: Boolean): WritableComparable[_] = {
+      statistics match {
+        case s: BooleanColumnStatistics =>
+          val value = if (isMax) s.getTrueCount > 0 else !(s.getFalseCount > 0)
+          new BooleanWritable(value)
+        case s: IntegerColumnStatistics =>
+          val value = if (isMax) s.getMaximum else s.getMinimum
+          dataType match {
+            case ByteType => new ByteWritable(value.toByte)
+            case ShortType => new ShortWritable(value.toShort)
+            case IntegerType => new IntWritable(value.toInt)
+            case LongType => new LongWritable(value)
+            case _ => throw new IllegalArgumentException(
+              s"getMaxFromColumnStatistics should not take type $dataType" +
+              "for IntegerColumnStatistics")
+          }
+        case s: DoubleColumnStatistics =>
+          val value = if (isMax) s.getMaximum else s.getMinimum
+          dataType match {
+            case FloatType => new FloatWritable(value.toFloat)
+            case DoubleType => new DoubleWritable(value)
+            case _ => throw new IllegalArgumentException(
+              s"getMaxFromColumnStatistics should not take type $dataType" +
+                "for DoubleColumnStatistics")
+          }
+        case s: DecimalColumnStatistics =>
+          new HiveDecimalWritable(if (isMax) s.getMaximum else s.getMinimum)
+        case s: StringColumnStatistics =>

Review comment:
       Is it OK to push down string for ORC? I remember you mentioned there are some issues similar to Parquet.

##########
File path: sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala
##########
@@ -377,4 +381,124 @@ object OrcUtils extends Logging {
       case _ => false
     }
   }
+
+  /**
+   * When the partial aggregates (Max/Min/Count) are pushed down to ORC, we don't need to read data
+   * from ORC and aggregate at Spark layer. Instead we want to get the partial aggregates
+   * (Max/Min/Count) result using the statistics information from ORC file footer, and then
+   * construct an InternalRow from these aggregate results.
+   *
+   * @return Aggregate results in the format of InternalRow
+   */
+  def createAggInternalRowFromFooter(
+      reader: Reader,
+      filePath: String,
+      dataSchema: StructType,
+      partitionSchema: StructType,
+      aggregation: Aggregation,
+      aggSchema: StructType,
+      isCaseSensitive: Boolean): InternalRow = {
+    require(aggregation.groupByColumns.length == 0,
+      s"aggregate $aggregation with group-by column shouldn't be pushed down")
+    val columnsStatistics = OrcFooterReader.readStatistics(reader)
+
+    // Get column statistics with column name.
+    def getColumnStatistics(columnName: String): ColumnStatistics = {
+      val columnIndex = dataSchema.fieldNames.indexOf(columnName)
+      columnsStatistics.get(columnIndex).getStatistics
+    }
+
+    // Get Min/Max statistics and store as ORC `WritableComparable` format.
+    def getMinMaxFromColumnStatistics(
+        statistics: ColumnStatistics,
+        dataType: DataType,
+        isMax: Boolean): WritableComparable[_] = {
+      statistics match {
+        case s: BooleanColumnStatistics =>
+          val value = if (isMax) s.getTrueCount > 0 else !(s.getFalseCount > 0)
+          new BooleanWritable(value)
+        case s: IntegerColumnStatistics =>
+          val value = if (isMax) s.getMaximum else s.getMinimum
+          dataType match {
+            case ByteType => new ByteWritable(value.toByte)
+            case ShortType => new ShortWritable(value.toShort)
+            case IntegerType => new IntWritable(value.toInt)
+            case LongType => new LongWritable(value)
+            case _ => throw new IllegalArgumentException(
+              s"getMaxFromColumnStatistics should not take type $dataType" +
+              "for IntegerColumnStatistics")
+          }
+        case s: DoubleColumnStatistics =>
+          val value = if (isMax) s.getMaximum else s.getMinimum
+          dataType match {
+            case FloatType => new FloatWritable(value.toFloat)
+            case DoubleType => new DoubleWritable(value)
+            case _ => throw new IllegalArgumentException(
+              s"getMaxFromColumnStatistics should not take type $dataType" +
+                "for DoubleColumnStatistics")
+          }
+        case s: DecimalColumnStatistics =>
+          new HiveDecimalWritable(if (isMax) s.getMaximum else s.getMinimum)
+        case s: StringColumnStatistics =>
+          new Text(if (isMax) s.getMaximum else s.getMinimum)
+        case s: DateColumnStatistics =>
+          new DateWritable(
+            if (isMax) s.getMaximumDayOfEpoch.toInt else s.getMinimumDayOfEpoch.toInt)
+        case _ => throw new IllegalArgumentException(
+          s"getMaxFromColumnStatistics should not take ${statistics.getClass.getName}: " +
+            s"$statistics as the ORC column statistics")
+      }
+    }
+
+    val aggORCValues: Seq[WritableComparable[_]] =
+      aggregation.aggregateExpressions.zipWithIndex.map {
+        case (max: Max, index) =>
+          val columnName = max.column.fieldNames.head
+          val statistics = getColumnStatistics(columnName)
+          val dataType = aggSchema(index).dataType
+          val value = getMinMaxFromColumnStatistics(statistics, dataType, isMax = true)
+          value
+        case (min: Min, index) =>
+          val columnName = min.column.fieldNames.head
+          val statistics = getColumnStatistics(columnName)
+          val dataType = aggSchema.apply(index).dataType
+          val value = getMinMaxFromColumnStatistics(statistics, dataType, isMax = false)
+          value
+        case (count: Count, _) =>
+          val columnName = count.column.fieldNames.head
+          val isPartitionColumn = partitionSchema.fields
+            .map(PartitioningUtils.getColName(_, isCaseSensitive))
+            .contains(columnName)
+          // NOTE: Count(columnName) doesn't include null values.
+          // org.apache.orc.ColumnStatistics.getNumberOfValues() returns number of non-null values
+          // for ColumnStatistics of individual column. In addition to this, ORC also returns
+          // number of non-null and null values for its top-level
+          // ColumnStatistics.getNumberOfValues().
+          val nonNullRowsCount = if (isPartitionColumn) {
+            val topLevelStatistics = columnsStatistics.getStatistics
+            if (topLevelStatistics.hasNull) {
+              throw new SparkException(s"Illegal ORC top-level column statistics with NULL " +

Review comment:
       I think we should also give users an informative error message so they know how to fall back.
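
       Something along these lines, as a rough sketch only. The wording is hypothetical, and
       the config key below is my assumption for the flag that toggles ORC aggregate push
       down; please substitute whatever key this PR actually defines.

```scala
object FallbackMessageSketch {
  // Assumption: name of the flag that enables ORC aggregate push down.
  val pushDownConfKey = "spark.sql.orc.aggregatePushdown"

  // Hypothetical helper: builds an error message that says what went wrong
  // and how the user can disable push down to fall back to a regular scan.
  def illegalTopLevelStatsMessage(filePath: String): String =
    s"Cannot use the ORC footer statistics of $filePath for aggregate push down. " +
      s"Set $pushDownConfKey to false to fall back to reading the data and retry."
}
```

       Including the file path would also help users find the offending file.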




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


