Posted to commits@spark.apache.org by we...@apache.org on 2020/08/25 04:46:30 UTC
[spark] branch branch-3.0 updated:
[SPARK-32646][SQL][3.0][TEST-HADOOP2.7][TEST-HIVE1.2] ORC predicate
pushdown should work with case-insensitive analysis
This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch branch-3.0
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.0 by this push:
new 6c88d7c [SPARK-32646][SQL][3.0][TEST-HADOOP2.7][TEST-HIVE1.2] ORC predicate pushdown should work with case-insensitive analysis
6c88d7c is described below
commit 6c88d7c1259ea9fe89f5c8190c683bba506d528e
Author: Liang-Chi Hsieh <vi...@gmail.com>
AuthorDate: Tue Aug 25 04:42:39 2020 +0000
[SPARK-32646][SQL][3.0][TEST-HADOOP2.7][TEST-HIVE1.2] ORC predicate pushdown should work with case-insensitive analysis
### What changes were proposed in this pull request?
This PR proposes to fix ORC predicate pushdown under case-insensitive analysis. When case-insensitive analysis is enabled, the field names in pushed-down predicates should not need to match the physical field names in ORC files in exact letter case.
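The key step, visible in the `OrcFiltersBase` change below, is to resolve pushed-down field names against the ORC physical schema case-insensitively. Here is a minimal, self-contained sketch of that resolution step (the `Field` type and the `resolve`/`lookup` helpers are illustrative, not the actual Spark helpers; the real `getDataTypeMap` also filters out non-searchable types and quotes field names):

```scala
import java.util.Locale

case class Field(name: String, dataType: String)

// Build a lookup map from predicate field name to physical ORC field.
def resolve(fields: Seq[Field], caseSensitive: Boolean): Map[String, Field] =
  if (caseSensitive) {
    fields.map(f => f.name -> f).toMap
  } else {
    // If two physical fields differ only by letter case, pushing either
    // one down would be ambiguous, so such fields are skipped entirely
    // (see SPARK-25175).
    fields.groupBy(_.name.toLowerCase(Locale.ROOT))
      .collect { case (lower, Seq(f)) => lower -> f }
  }

// Lookups must normalize the key the same way in case-insensitive mode.
def lookup(m: Map[String, Field], name: String, caseSensitive: Boolean): Option[Field] =
  m.get(if (caseSensitive) name else name.toLowerCase(Locale.ROOT))
```

With such a map, `buildLeafSearchArgument` can guard each case on `dataTypeMap.contains(name)` and emit the physical field name (`dataTypeMap(name).fieldName`) into the SearchArgument, instead of the name as written in the query.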
### Why are the changes needed?
Currently, ORC predicate pushdown doesn't work with case-insensitive analysis: even with case-insensitive analysis enabled, a predicate "a < 0" cannot be pushed down to an ORC file whose field is named "A".
Parquet predicate pushdown already handles this case, so ORC predicate pushdown should work with case-insensitive analysis too.
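A minimal reproduction sketch (the path and session setup here are illustrative), assuming a running `SparkSession` named `spark`:

```scala
import org.apache.spark.sql.functions.col

spark.conf.set("spark.sql.caseSensitive", "false")

// Write an ORC file whose physical field name is upper-case "A".
spark.range(10).toDF("A").write.mode("overwrite").orc("/tmp/orc_case")

// Read it back with a lower-case schema and a pushable predicate.
val df = spark.read.schema("a LONG").orc("/tmp/orc_case").filter(col("a") < 5)
df.show()
// Before this fix, the predicate "a < 5" is not pushed into the ORC reader
// because "a" does not exactly match the physical name "A"; with this fix,
// the pushed-down predicate resolves the field case-insensitively.
```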
### Does this PR introduce _any_ user-facing change?
Yes. After this PR, ORC predicate pushdown works under case-insensitive analysis.
### How was this patch tested?
Unit tests.
Closes #29513 from viirya/fix-orc-pushdown-3.0.
Authored-by: Liang-Chi Hsieh <vi...@gmail.com>
Signed-off-by: Wenchen Fan <we...@databricks.com>
---
.../execution/datasources/orc/OrcFileFormat.scala | 16 +++--
.../execution/datasources/orc/OrcFiltersBase.scala | 35 ++++++++++-
.../sql/execution/datasources/orc/OrcUtils.scala | 14 +++++
.../v2/orc/OrcPartitionReaderFactory.scala | 22 ++++++-
.../sql/execution/datasources/v2/orc/OrcScan.scala | 2 +-
.../datasources/v2/orc/OrcScanBuilder.scala | 9 +--
.../sql/execution/datasources/orc/OrcFilters.scala | 72 ++++++++++++----------
.../execution/datasources/orc/OrcFilterSuite.scala | 49 ++++++++++++++-
.../sql/execution/datasources/orc/OrcFilters.scala | 70 +++++++++++----------
.../execution/datasources/orc/OrcFilterSuite.scala | 49 ++++++++++++++-
10 files changed, 253 insertions(+), 85 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala
index 4dff1ec..69badb4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala
@@ -153,11 +153,6 @@ class OrcFileFormat
filters: Seq[Filter],
options: Map[String, String],
hadoopConf: Configuration): (PartitionedFile) => Iterator[InternalRow] = {
- if (sparkSession.sessionState.conf.orcFilterPushDown) {
- OrcFilters.createFilter(dataSchema, filters).foreach { f =>
- OrcInputFormat.setSearchArgument(hadoopConf, f, dataSchema.fieldNames)
- }
- }
val resultSchema = StructType(requiredSchema.fields ++ partitionSchema.fields)
val sqlConf = sparkSession.sessionState.conf
@@ -169,6 +164,8 @@ class OrcFileFormat
val broadcastedConf =
sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf))
val isCaseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis
+ val orcFilterPushDown = sparkSession.sessionState.conf.orcFilterPushDown
+ val ignoreCorruptFiles = sparkSession.sessionState.conf.ignoreCorruptFiles
(file: PartitionedFile) => {
val conf = broadcastedConf.value.value
@@ -186,6 +183,15 @@ class OrcFileFormat
if (resultedColPruneInfo.isEmpty) {
Iterator.empty
} else {
+ // ORC predicate pushdown
+ if (orcFilterPushDown) {
+ OrcUtils.readCatalystSchema(filePath, conf, ignoreCorruptFiles).map { fileSchema =>
+ OrcFilters.createFilter(fileSchema, filters).foreach { f =>
+ OrcInputFormat.setSearchArgument(conf, f, fileSchema.fieldNames)
+ }
+ }
+ }
+
val (requestedColIds, canPruneCols) = resultedColPruneInfo.get
val resultSchemaString = OrcUtils.orcResultSchemaString(canPruneCols,
dataSchema, resultSchema, partitionSchema, conf)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFiltersBase.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFiltersBase.scala
index e673309..4554899 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFiltersBase.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFiltersBase.scala
@@ -17,14 +17,45 @@
package org.apache.spark.sql.execution.datasources.orc
+import java.util.Locale
+
+import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
+import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.quoteIfNeeded
import org.apache.spark.sql.sources.{And, Filter}
-import org.apache.spark.sql.types.{AtomicType, BinaryType, DataType}
+import org.apache.spark.sql.types.{AtomicType, BinaryType, DataType, StructType}
/**
* Methods that can be shared when upgrading the built-in Hive.
*/
trait OrcFiltersBase {
+ case class OrcPrimitiveField(fieldName: String, fieldType: DataType)
+
+ protected[sql] def getDataTypeMap(
+ schema: StructType,
+ caseSensitive: Boolean): Map[String, OrcPrimitiveField] = {
+ val fields = schema.flatMap { f =>
+ if (isSearchableType(f.dataType)) {
+ Some(quoteIfNeeded(f.name) -> OrcPrimitiveField(quoteIfNeeded(f.name), f.dataType))
+ } else {
+ None
+ }
+ }
+
+ if (caseSensitive) {
+ fields.toMap
+ } else {
+ // Don't consider ambiguity here, i.e. more than one field are matched in case insensitive
+ // mode, just skip pushdown for these fields, they will trigger Exception when reading,
+ // See: SPARK-25175.
+ val dedupPrimitiveFields = fields
+ .groupBy(_._1.toLowerCase(Locale.ROOT))
+ .filter(_._2.size == 1)
+ .mapValues(_.head._2)
+ CaseInsensitiveMap(dedupPrimitiveFields)
+ }
+ }
+
private[sql] def buildTree(filters: Seq[Filter]): Option[Filter] = {
filters match {
case Seq() => None
@@ -40,7 +71,7 @@ trait OrcFiltersBase {
* Return true if this is a searchable type in ORC.
* Both CharType and VarcharType are cleaned at AstBuilder.
*/
- protected[sql] def isSearchableType(dataType: DataType) = dataType match {
+ private def isSearchableType(dataType: DataType) = dataType match {
case BinaryType => false
case _: AtomicType => true
case _ => false
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala
index e102539..be36432 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala
@@ -92,6 +92,20 @@ object OrcUtils extends Logging {
}
}
+ def readCatalystSchema(
+ file: Path,
+ conf: Configuration,
+ ignoreCorruptFiles: Boolean): Option[StructType] = {
+ readSchema(file, conf, ignoreCorruptFiles) match {
+ case Some(schema) =>
+ Some(CatalystSqlParser.parseDataType(schema.toString).asInstanceOf[StructType])
+
+ case None =>
+ // Field names is empty or `FileFormatException` was thrown but ignoreCorruptFiles is true.
+ None
+ }
+ }
+
/**
* Reads ORC file schemas in multi-threaded manner, using native version of ORC.
* This is visible for testing.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcPartitionReaderFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcPartitionReaderFactory.scala
index 7f25f7bd..1f38128 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcPartitionReaderFactory.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcPartitionReaderFactory.scala
@@ -31,9 +31,10 @@ import org.apache.spark.broadcast.Broadcast
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader}
import org.apache.spark.sql.execution.datasources.PartitionedFile
-import org.apache.spark.sql.execution.datasources.orc.{OrcColumnarBatchReader, OrcDeserializer, OrcUtils}
+import org.apache.spark.sql.execution.datasources.orc.{OrcColumnarBatchReader, OrcDeserializer, OrcFilters, OrcUtils}
import org.apache.spark.sql.execution.datasources.v2._
import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.sources.Filter
import org.apache.spark.sql.types.{AtomicType, StructType}
import org.apache.spark.sql.vectorized.ColumnarBatch
import org.apache.spark.util.{SerializableConfiguration, Utils}
@@ -52,10 +53,13 @@ case class OrcPartitionReaderFactory(
broadcastedConf: Broadcast[SerializableConfiguration],
dataSchema: StructType,
readDataSchema: StructType,
- partitionSchema: StructType) extends FilePartitionReaderFactory {
+ partitionSchema: StructType,
+ filters: Array[Filter]) extends FilePartitionReaderFactory {
private val resultSchema = StructType(readDataSchema.fields ++ partitionSchema.fields)
private val isCaseSensitive = sqlConf.caseSensitiveAnalysis
private val capacity = sqlConf.orcVectorizedReaderBatchSize
+ private val orcFilterPushDown = sqlConf.orcFilterPushDown
+ private val ignoreCorruptFiles = sqlConf.ignoreCorruptFiles
override def supportColumnarReads(partition: InputPartition): Boolean = {
sqlConf.orcVectorizedReaderEnabled && sqlConf.wholeStageEnabled &&
@@ -63,6 +67,16 @@ case class OrcPartitionReaderFactory(
resultSchema.forall(_.dataType.isInstanceOf[AtomicType])
}
+ private def pushDownPredicates(filePath: Path, conf: Configuration): Unit = {
+ if (orcFilterPushDown) {
+ OrcUtils.readCatalystSchema(filePath, conf, ignoreCorruptFiles).map { fileSchema =>
+ OrcFilters.createFilter(fileSchema, filters).foreach { f =>
+ OrcInputFormat.setSearchArgument(conf, f, fileSchema.fieldNames)
+ }
+ }
+ }
+ }
+
override def buildReader(file: PartitionedFile): PartitionReader[InternalRow] = {
val conf = broadcastedConf.value.value
@@ -70,6 +84,8 @@ case class OrcPartitionReaderFactory(
val filePath = new Path(new URI(file.filePath))
+ pushDownPredicates(filePath, conf)
+
val fs = filePath.getFileSystem(conf)
val readerOptions = OrcFile.readerOptions(conf).filesystem(fs)
val resultedColPruneInfo =
@@ -116,6 +132,8 @@ case class OrcPartitionReaderFactory(
val filePath = new Path(new URI(file.filePath))
+ pushDownPredicates(filePath, conf)
+
val fs = filePath.getFileSystem(conf)
val readerOptions = OrcFile.readerOptions(conf).filesystem(fs)
val resultedColPruneInfo =
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala
index 62894fa..35e3b1a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala
@@ -48,7 +48,7 @@ case class OrcScan(
// The partition values are already truncated in `FileScan.partitions`.
// We should use `readPartitionSchema` as the partition schema here.
OrcPartitionReaderFactory(sparkSession.sessionState.conf, broadcastedConf,
- dataSchema, readDataSchema, readPartitionSchema)
+ dataSchema, readDataSchema, readPartitionSchema, pushedFilters)
}
override def equals(obj: Any): Boolean = obj match {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScanBuilder.scala
index 9f40f5f..6a9cb25 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScanBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScanBuilder.scala
@@ -22,11 +22,11 @@ import scala.collection.JavaConverters._
import org.apache.orc.mapreduce.OrcInputFormat
import org.apache.spark.sql.SparkSession
-import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.quoteIfNeeded
import org.apache.spark.sql.connector.read.{Scan, SupportsPushDownFilters}
import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex
import org.apache.spark.sql.execution.datasources.orc.OrcFilters
import org.apache.spark.sql.execution.datasources.v2.FileScanBuilder
+import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.sources.Filter
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.CaseInsensitiveStringMap
@@ -55,12 +55,7 @@ case class OrcScanBuilder(
override def pushFilters(filters: Array[Filter]): Array[Filter] = {
if (sparkSession.sessionState.conf.orcFilterPushDown) {
- OrcFilters.createFilter(schema, filters).foreach { f =>
- // The pushed filters will be set in `hadoopConf`. After that, we can simply use the
- // changed `hadoopConf` in executors.
- OrcInputFormat.setSearchArgument(hadoopConf, f, schema.fieldNames)
- }
- val dataTypeMap = schema.map(f => quoteIfNeeded(f.name) -> f.dataType).toMap
+ val dataTypeMap = OrcFilters.getDataTypeMap(schema, SQLConf.get.caseSensitiveAnalysis)
// TODO (SPARK-25557): ORC doesn't support nested predicate pushdown, so they are removed.
val newFilters = filters.filter(!_.containsNestedColumn)
_pushedFilters = OrcFilters.convertibleFilters(schema, dataTypeMap, newFilters).toArray
diff --git a/sql/core/v1.2/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala b/sql/core/v1.2/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala
index b685639..a068347 100644
--- a/sql/core/v1.2/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala
+++ b/sql/core/v1.2/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala
@@ -27,7 +27,7 @@ import org.apache.orc.storage.serde2.io.HiveDecimalWritable
import org.apache.spark.SparkException
import org.apache.spark.sql.catalyst.util.DateTimeUtils.{instantToMicros, localDateToDays, toJavaDate, toJavaTimestamp}
-import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.quoteIfNeeded
+import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.sources.Filter
import org.apache.spark.sql.types._
@@ -68,7 +68,7 @@ private[sql] object OrcFilters extends OrcFiltersBase {
* Create ORC filter as a SearchArgument instance.
*/
def createFilter(schema: StructType, filters: Seq[Filter]): Option[SearchArgument] = {
- val dataTypeMap = schema.map(f => quoteIfNeeded(f.name) -> f.dataType).toMap
+ val dataTypeMap = getDataTypeMap(schema, SQLConf.get.caseSensitiveAnalysis)
// Combines all convertible filters using `And` to produce a single conjunction
// TODO (SPARK-25557): ORC doesn't support nested predicate pushdown, so they are removed.
val newFilters = filters.filter(!_.containsNestedColumn)
@@ -83,7 +83,7 @@ private[sql] object OrcFilters extends OrcFiltersBase {
def convertibleFilters(
schema: StructType,
- dataTypeMap: Map[String, DataType],
+ dataTypeMap: Map[String, OrcPrimitiveField],
filters: Seq[Filter]): Seq[Filter] = {
import org.apache.spark.sql.sources._
@@ -141,7 +141,7 @@ private[sql] object OrcFilters extends OrcFiltersBase {
/**
* Get PredicateLeafType which is corresponding to the given DataType.
*/
- private def getPredicateLeafType(dataType: DataType) = dataType match {
+ private[sql] def getPredicateLeafType(dataType: DataType) = dataType match {
case BooleanType => PredicateLeaf.Type.BOOLEAN
case ByteType | ShortType | IntegerType | LongType => PredicateLeaf.Type.LONG
case FloatType | DoubleType => PredicateLeaf.Type.FLOAT
@@ -181,7 +181,7 @@ private[sql] object OrcFilters extends OrcFiltersBase {
* @return the builder so far.
*/
private def buildSearchArgument(
- dataTypeMap: Map[String, DataType],
+ dataTypeMap: Map[String, OrcPrimitiveField],
expression: Filter,
builder: Builder): Builder = {
import org.apache.spark.sql.sources._
@@ -217,11 +217,11 @@ private[sql] object OrcFilters extends OrcFiltersBase {
* @return the builder so far.
*/
private def buildLeafSearchArgument(
- dataTypeMap: Map[String, DataType],
+ dataTypeMap: Map[String, OrcPrimitiveField],
expression: Filter,
builder: Builder): Option[Builder] = {
def getType(attribute: String): PredicateLeaf.Type =
- getPredicateLeafType(dataTypeMap(attribute))
+ getPredicateLeafType(dataTypeMap(attribute).fieldType)
import org.apache.spark.sql.sources._
@@ -231,39 +231,47 @@ private[sql] object OrcFilters extends OrcFiltersBase {
// Since ORC 1.5.0 (ORC-323), we need to quote for column names with `.` characters
// in order to distinguish predicate pushdown for nested columns.
expression match {
- case EqualTo(name, value) if isSearchableType(dataTypeMap(name)) =>
- val castedValue = castLiteralValue(value, dataTypeMap(name))
- Some(builder.startAnd().equals(name, getType(name), castedValue).end())
+ case EqualTo(name, value) if dataTypeMap.contains(name) =>
+ val castedValue = castLiteralValue(value, dataTypeMap(name).fieldType)
+ Some(builder.startAnd()
+ .equals(dataTypeMap(name).fieldName, getType(name), castedValue).end())
- case EqualNullSafe(name, value) if isSearchableType(dataTypeMap(name)) =>
- val castedValue = castLiteralValue(value, dataTypeMap(name))
- Some(builder.startAnd().nullSafeEquals(name, getType(name), castedValue).end())
+ case EqualNullSafe(name, value) if dataTypeMap.contains(name) =>
+ val castedValue = castLiteralValue(value, dataTypeMap(name).fieldType)
+ Some(builder.startAnd()
+ .nullSafeEquals(dataTypeMap(name).fieldName, getType(name), castedValue).end())
- case LessThan(name, value) if isSearchableType(dataTypeMap(name)) =>
- val castedValue = castLiteralValue(value, dataTypeMap(name))
- Some(builder.startAnd().lessThan(name, getType(name), castedValue).end())
+ case LessThan(name, value) if dataTypeMap.contains(name) =>
+ val castedValue = castLiteralValue(value, dataTypeMap(name).fieldType)
+ Some(builder.startAnd()
+ .lessThan(dataTypeMap(name).fieldName, getType(name), castedValue).end())
- case LessThanOrEqual(name, value) if isSearchableType(dataTypeMap(name)) =>
- val castedValue = castLiteralValue(value, dataTypeMap(name))
- Some(builder.startAnd().lessThanEquals(name, getType(name), castedValue).end())
+ case LessThanOrEqual(name, value) if dataTypeMap.contains(name) =>
+ val castedValue = castLiteralValue(value, dataTypeMap(name).fieldType)
+ Some(builder.startAnd()
+ .lessThanEquals(dataTypeMap(name).fieldName, getType(name), castedValue).end())
- case GreaterThan(name, value) if isSearchableType(dataTypeMap(name)) =>
- val castedValue = castLiteralValue(value, dataTypeMap(name))
- Some(builder.startNot().lessThanEquals(name, getType(name), castedValue).end())
+ case GreaterThan(name, value) if dataTypeMap.contains(name) =>
+ val castedValue = castLiteralValue(value, dataTypeMap(name).fieldType)
+ Some(builder.startNot()
+ .lessThanEquals(dataTypeMap(name).fieldName, getType(name), castedValue).end())
- case GreaterThanOrEqual(name, value) if isSearchableType(dataTypeMap(name)) =>
- val castedValue = castLiteralValue(value, dataTypeMap(name))
- Some(builder.startNot().lessThan(name, getType(name), castedValue).end())
+ case GreaterThanOrEqual(name, value) if dataTypeMap.contains(name) =>
+ val castedValue = castLiteralValue(value, dataTypeMap(name).fieldType)
+ Some(builder.startNot()
+ .lessThan(dataTypeMap(name).fieldName, getType(name), castedValue).end())
- case IsNull(name) if isSearchableType(dataTypeMap(name)) =>
- Some(builder.startAnd().isNull(name, getType(name)).end())
+ case IsNull(name) if dataTypeMap.contains(name) =>
+ Some(builder.startAnd()
+ .isNull(dataTypeMap(name).fieldName, getType(name)).end())
- case IsNotNull(name) if isSearchableType(dataTypeMap(name)) =>
- Some(builder.startNot().isNull(name, getType(name)).end())
+ case IsNotNull(name) if dataTypeMap.contains(name) =>
+ Some(builder.startNot()
+ .isNull(dataTypeMap(name).fieldName, getType(name)).end())
- case In(name, values) if isSearchableType(dataTypeMap(name)) =>
- val castedValues = values.map(v => castLiteralValue(v, dataTypeMap(name)))
- Some(builder.startAnd().in(name, getType(name),
+ case In(name, values) if dataTypeMap.contains(name) =>
+ val castedValues = values.map(v => castLiteralValue(v, dataTypeMap(name).fieldType))
+ Some(builder.startAnd().in(dataTypeMap(name).fieldName, getType(name),
castedValues.map(_.asInstanceOf[AnyRef]): _*).end())
case _ => None
diff --git a/sql/core/v1.2/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala b/sql/core/v1.2/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala
index beb7232..a3c2343 100644
--- a/sql/core/v1.2/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala
+++ b/sql/core/v1.2/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala
@@ -24,6 +24,7 @@ import java.sql.{Date, Timestamp}
import scala.collection.JavaConverters._
import org.apache.orc.storage.ql.io.sarg.{PredicateLeaf, SearchArgument}
+import org.apache.orc.storage.ql.io.sarg.SearchArgumentFactory.newBuilder
import org.apache.spark.{SparkConf, SparkException}
import org.apache.spark.sql.{AnalysisException, Column, DataFrame, Row}
@@ -542,8 +543,7 @@ class OrcFilterSuite extends OrcTest with SharedSparkSession {
checkAnswer(sql(s"select a from $tableName"), (0 until count).map(c => Row(c - 1)))
val actual = stripSparkFilter(sql(s"select a from $tableName where a < 0"))
- // TODO: ORC predicate pushdown should work under case-insensitive analysis.
- // assert(actual.count() == 1)
+ assert(actual.count() == 1)
}
}
@@ -562,5 +562,50 @@ class OrcFilterSuite extends OrcTest with SharedSparkSession {
}
}
}
+
+ test("SPARK-32646: Case-insensitive field resolution for pushdown when reading ORC") {
+ import org.apache.spark.sql.sources._
+
+ def getOrcFilter(
+ schema: StructType,
+ filters: Seq[Filter],
+ caseSensitive: String): Option[SearchArgument] = {
+ var orcFilter: Option[SearchArgument] = None
+ withSQLConf(SQLConf.CASE_SENSITIVE.key -> caseSensitive) {
+ orcFilter =
+ OrcFilters.createFilter(schema, filters)
+ }
+ orcFilter
+ }
+
+ def testFilter(
+ schema: StructType,
+ filters: Seq[Filter],
+ expected: SearchArgument): Unit = {
+ val caseSensitiveFilters = getOrcFilter(schema, filters, "true")
+ val caseInsensitiveFilters = getOrcFilter(schema, filters, "false")
+
+ assert(caseSensitiveFilters.isEmpty)
+ assert(caseInsensitiveFilters.isDefined)
+
+ assert(caseInsensitiveFilters.get.getLeaves().size() > 0)
+ assert(caseInsensitiveFilters.get.getLeaves().size() == expected.getLeaves().size())
+ (0 until expected.getLeaves().size()).foreach { index =>
+ assert(caseInsensitiveFilters.get.getLeaves().get(index) == expected.getLeaves().get(index))
+ }
+ }
+
+ val schema = StructType(Seq(StructField("cint", IntegerType)))
+ testFilter(schema, Seq(GreaterThan("CINT", 1)),
+ newBuilder.startNot()
+ .lessThanEquals("cint", OrcFilters.getPredicateLeafType(IntegerType), 1L).`end`().build())
+ testFilter(schema, Seq(
+ And(GreaterThan("CINT", 1), EqualTo("Cint", 2))),
+ newBuilder.startAnd()
+ .startNot()
+ .lessThanEquals("cint", OrcFilters.getPredicateLeafType(IntegerType), 1L).`end`()
+ .equals("cint", OrcFilters.getPredicateLeafType(IntegerType), 2L)
+ .`end`().build())
+ }
}
diff --git a/sql/core/v2.3/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala b/sql/core/v2.3/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala
index 4b64208..9f1927e 100644
--- a/sql/core/v2.3/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala
+++ b/sql/core/v2.3/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala
@@ -27,7 +27,7 @@ import org.apache.hadoop.hive.serde2.io.HiveDecimalWritable
import org.apache.spark.SparkException
import org.apache.spark.sql.catalyst.util.DateTimeUtils.{instantToMicros, localDateToDays, toJavaDate, toJavaTimestamp}
-import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.quoteIfNeeded
+import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.sources.Filter
import org.apache.spark.sql.types._
@@ -68,7 +68,7 @@ private[sql] object OrcFilters extends OrcFiltersBase {
* Create ORC filter as a SearchArgument instance.
*/
def createFilter(schema: StructType, filters: Seq[Filter]): Option[SearchArgument] = {
- val dataTypeMap = schema.map(f => quoteIfNeeded(f.name) -> f.dataType).toMap
+ val dataTypeMap = getDataTypeMap(schema, SQLConf.get.caseSensitiveAnalysis)
// Combines all convertible filters using `And` to produce a single conjunction
// TODO (SPARK-25557): ORC doesn't support nested predicate pushdown, so they are removed.
val newFilters = filters.filter(!_.containsNestedColumn)
@@ -83,7 +83,7 @@ private[sql] object OrcFilters extends OrcFiltersBase {
def convertibleFilters(
schema: StructType,
- dataTypeMap: Map[String, DataType],
+ dataTypeMap: Map[String, OrcPrimitiveField],
filters: Seq[Filter]): Seq[Filter] = {
import org.apache.spark.sql.sources._
@@ -141,7 +141,7 @@ private[sql] object OrcFilters extends OrcFiltersBase {
/**
* Get PredicateLeafType which is corresponding to the given DataType.
*/
- private def getPredicateLeafType(dataType: DataType) = dataType match {
+ private[sql] def getPredicateLeafType(dataType: DataType) = dataType match {
case BooleanType => PredicateLeaf.Type.BOOLEAN
case ByteType | ShortType | IntegerType | LongType => PredicateLeaf.Type.LONG
case FloatType | DoubleType => PredicateLeaf.Type.FLOAT
@@ -181,7 +181,7 @@ private[sql] object OrcFilters extends OrcFiltersBase {
* @return the builder so far.
*/
private def buildSearchArgument(
- dataTypeMap: Map[String, DataType],
+ dataTypeMap: Map[String, OrcPrimitiveField],
expression: Filter,
builder: Builder): Builder = {
import org.apache.spark.sql.sources._
@@ -217,11 +217,11 @@ private[sql] object OrcFilters extends OrcFiltersBase {
* @return the builder so far.
*/
private def buildLeafSearchArgument(
- dataTypeMap: Map[String, DataType],
+ dataTypeMap: Map[String, OrcPrimitiveField],
expression: Filter,
builder: Builder): Option[Builder] = {
def getType(attribute: String): PredicateLeaf.Type =
- getPredicateLeafType(dataTypeMap(attribute))
+ getPredicateLeafType(dataTypeMap(attribute).fieldType)
import org.apache.spark.sql.sources._
@@ -231,39 +231,45 @@ private[sql] object OrcFilters extends OrcFiltersBase {
// Since ORC 1.5.0 (ORC-323), we need to quote for column names with `.` characters
// in order to distinguish predicate pushdown for nested columns.
expression match {
- case EqualTo(name, value) if isSearchableType(dataTypeMap(name)) =>
- val castedValue = castLiteralValue(value, dataTypeMap(name))
- Some(builder.startAnd().equals(name, getType(name), castedValue).end())
+ case EqualTo(name, value) if dataTypeMap.contains(name) =>
+ val castedValue = castLiteralValue(value, dataTypeMap(name).fieldType)
+ Some(builder.startAnd()
+ .equals(dataTypeMap(name).fieldName, getType(name), castedValue).end())
- case EqualNullSafe(name, value) if isSearchableType(dataTypeMap(name)) =>
- val castedValue = castLiteralValue(value, dataTypeMap(name))
- Some(builder.startAnd().nullSafeEquals(name, getType(name), castedValue).end())
+ case EqualNullSafe(name, value) if dataTypeMap.contains(name) =>
+ val castedValue = castLiteralValue(value, dataTypeMap(name).fieldType)
+ Some(builder.startAnd()
+ .nullSafeEquals(dataTypeMap(name).fieldName, getType(name), castedValue).end())
- case LessThan(name, value) if isSearchableType(dataTypeMap(name)) =>
- val castedValue = castLiteralValue(value, dataTypeMap(name))
- Some(builder.startAnd().lessThan(name, getType(name), castedValue).end())
+ case LessThan(name, value) if dataTypeMap.contains(name) =>
+ val castedValue = castLiteralValue(value, dataTypeMap(name).fieldType)
+ Some(builder.startAnd()
+ .lessThan(dataTypeMap(name).fieldName, getType(name), castedValue).end())
- case LessThanOrEqual(name, value) if isSearchableType(dataTypeMap(name)) =>
- val castedValue = castLiteralValue(value, dataTypeMap(name))
- Some(builder.startAnd().lessThanEquals(name, getType(name), castedValue).end())
+ case LessThanOrEqual(name, value) if dataTypeMap.contains(name) =>
+ val castedValue = castLiteralValue(value, dataTypeMap(name).fieldType)
+ Some(builder.startAnd()
+ .lessThanEquals(dataTypeMap(name).fieldName, getType(name), castedValue).end())
- case GreaterThan(name, value) if isSearchableType(dataTypeMap(name)) =>
- val castedValue = castLiteralValue(value, dataTypeMap(name))
- Some(builder.startNot().lessThanEquals(name, getType(name), castedValue).end())
+ case GreaterThan(name, value) if dataTypeMap.contains(name) =>
+ val castedValue = castLiteralValue(value, dataTypeMap(name).fieldType)
+ Some(builder.startNot()
+ .lessThanEquals(dataTypeMap(name).fieldName, getType(name), castedValue).end())
- case GreaterThanOrEqual(name, value) if isSearchableType(dataTypeMap(name)) =>
- val castedValue = castLiteralValue(value, dataTypeMap(name))
- Some(builder.startNot().lessThan(name, getType(name), castedValue).end())
+ case GreaterThanOrEqual(name, value) if dataTypeMap.contains(name) =>
+ val castedValue = castLiteralValue(value, dataTypeMap(name).fieldType)
+ Some(builder.startNot()
+ .lessThan(dataTypeMap(name).fieldName, getType(name), castedValue).end())
- case IsNull(name) if isSearchableType(dataTypeMap(name)) =>
- Some(builder.startAnd().isNull(name, getType(name)).end())
+ case IsNull(name) if dataTypeMap.contains(name) =>
+ Some(builder.startAnd().isNull(dataTypeMap(name).fieldName, getType(name)).end())
- case IsNotNull(name) if isSearchableType(dataTypeMap(name)) =>
- Some(builder.startNot().isNull(name, getType(name)).end())
+ case IsNotNull(name) if dataTypeMap.contains(name) =>
+ Some(builder.startNot().isNull(dataTypeMap(name).fieldName, getType(name)).end())
- case In(name, values) if isSearchableType(dataTypeMap(name)) =>
- val castedValues = values.map(v => castLiteralValue(v, dataTypeMap(name)))
- Some(builder.startAnd().in(name, getType(name),
+ case In(name, values) if dataTypeMap.contains(name) =>
+ val castedValues = values.map(v => castLiteralValue(v, dataTypeMap(name).fieldType))
+ Some(builder.startAnd().in(dataTypeMap(name).fieldName, getType(name),
castedValues.map(_.asInstanceOf[AnyRef]): _*).end())
case _ => None
diff --git a/sql/core/v2.3/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala b/sql/core/v2.3/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala
index a3e450c..cb69413 100644
--- a/sql/core/v2.3/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala
+++ b/sql/core/v2.3/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala
@@ -24,6 +24,7 @@ import java.sql.{Date, Timestamp}
import scala.collection.JavaConverters._
import org.apache.hadoop.hive.ql.io.sarg.{PredicateLeaf, SearchArgument}
+import org.apache.hadoop.hive.ql.io.sarg.SearchArgumentFactory.newBuilder
import org.apache.spark.{SparkConf, SparkException}
import org.apache.spark.sql.{AnalysisException, Column, DataFrame, Row}
@@ -543,8 +544,7 @@ class OrcFilterSuite extends OrcTest with SharedSparkSession {
checkAnswer(sql(s"select a from $tableName"), (0 until count).map(c => Row(c - 1)))
val actual = stripSparkFilter(sql(s"select a from $tableName where a < 0"))
- // TODO: ORC predicate pushdown should work under case-insensitive analysis.
- // assert(actual.count() == 1)
+ assert(actual.count() == 1)
}
}
@@ -563,5 +563,50 @@ class OrcFilterSuite extends OrcTest with SharedSparkSession {
}
}
}
+
+ test("SPARK-32646: Case-insensitive field resolution for pushdown when reading ORC") {
+ import org.apache.spark.sql.sources._
+
+ def getOrcFilter(
+ schema: StructType,
+ filters: Seq[Filter],
+ caseSensitive: String): Option[SearchArgument] = {
+ var orcFilter: Option[SearchArgument] = None
+ withSQLConf(SQLConf.CASE_SENSITIVE.key -> caseSensitive) {
+ orcFilter =
+ OrcFilters.createFilter(schema, filters)
+ }
+ orcFilter
+ }
+
+ def testFilter(
+ schema: StructType,
+ filters: Seq[Filter],
+ expected: SearchArgument): Unit = {
+ val caseSensitiveFilters = getOrcFilter(schema, filters, "true")
+ val caseInsensitiveFilters = getOrcFilter(schema, filters, "false")
+
+ assert(caseSensitiveFilters.isEmpty)
+ assert(caseInsensitiveFilters.isDefined)
+
+ assert(caseInsensitiveFilters.get.getLeaves().size() > 0)
+ assert(caseInsensitiveFilters.get.getLeaves().size() == expected.getLeaves().size())
+ (0 until expected.getLeaves().size()).foreach { index =>
+ assert(caseInsensitiveFilters.get.getLeaves().get(index) == expected.getLeaves().get(index))
+ }
+ }
+
+ val schema = StructType(Seq(StructField("cint", IntegerType)))
+ testFilter(schema, Seq(GreaterThan("CINT", 1)),
+ newBuilder.startNot()
+ .lessThanEquals("cint", OrcFilters.getPredicateLeafType(IntegerType), 1L).`end`().build())
+ testFilter(schema, Seq(
+ And(GreaterThan("CINT", 1), EqualTo("Cint", 2))),
+ newBuilder.startAnd()
+ .startNot()
+ .lessThanEquals("cint", OrcFilters.getPredicateLeafType(IntegerType), 1L).`end`()
+ .equals("cint", OrcFilters.getPredicateLeafType(IntegerType), 2L)
+ .`end`().build())
+ }
}