Posted to commits@spark.apache.org by do...@apache.org on 2019/03/12 22:39:39 UTC

[spark] branch master updated: [SPARK-27034][SQL] Nested schema pruning for ORC

This is an automated email from the ASF dual-hosted git repository.

dongjoon pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new b0c2b3b  [SPARK-27034][SQL] Nested schema pruning for ORC
b0c2b3b is described below

commit b0c2b3bfd9b43ab97b37532abfef22e14642125c
Author: Liang-Chi Hsieh <vi...@gmail.com>
AuthorDate: Tue Mar 12 15:39:16 2019 -0700

    [SPARK-27034][SQL] Nested schema pruning for ORC
    
    ## What changes were proposed in this pull request?
    
    Previously, nested schema pruning was supported only for Parquet. This adds nested schema pruning support for ORC as well.
    
    Note: this only covers ORC v1. For ORC v2, the necessary change is in the schema pruning rule; ORC v2 support is left as a follow-up TODO to keep this review manageable.
    
    ## How was this patch tested?
    
    Added tests.
    
    Closes #23943 from viirya/nested-schema-pruning-orc.
    
    Authored-by: Liang-Chi Hsieh <vi...@gmail.com>
    Signed-off-by: Dongjoon Hyun <dh...@apple.com>
---
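A quick way to see the new behaviour end to end is sketched below. This is not part of the
patch: the path, case classes, and sample data are illustrative assumptions, and the sketch
assumes a spark-shell session built from this branch. The config key is the one whose
documentation is updated in SQLConf below; it still defaults to false.

    // Hypothetical spark-shell session exercising nested schema pruning on ORC.
    import spark.implicits._

    spark.conf.set("spark.sql.optimizer.nestedSchemaPruning.enabled", "true")

    case class Name(first: String, middle: String, last: String)
    case class Contact(id: Int, name: Name, address: String)

    Seq(Contact(0, Name("Jane", "X.", "Doe"), "123 Main Street"),
        Contact(1, Name("John", "Y.", "Doe"), "321 Wall Street"))
      .toDF()
      .write.mode("overwrite").orc("/tmp/contacts_orc")

    val query = spark.read.orc("/tmp/contacts_orc").select("name.middle")
    // With the flag on, the ORC scan should report a pruned ReadSchema such as
    // struct<name:struct<middle:string>> rather than the full Contact schema.
    query.explain()
    query.show()
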
 .../org/apache/spark/sql/internal/SQLConf.scala    |   4 +-
 .../OrcNestedSchemaPruningBenchmark-results.txt    |  20 +-
 .../execution/datasources/orc/OrcFileFormat.scala  |   9 +-
 .../sql/execution/datasources/orc/OrcUtils.scala   |  35 +-
 .../datasources/parquet/ParquetSchemaPruning.scala |   4 +-
 .../v2/orc/OrcPartitionReaderFactory.scala         |  23 +-
 ...PruningSuite.scala => SchemaPruningSuite.scala} |  64 ++--
 .../datasources/orc/OrcSchemaPruningSuite.scala    |  34 ++
 .../parquet/ParquetSchemaPruningSuite.scala        | 389 +--------------------
 9 files changed, 142 insertions(+), 440 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 6f483a7..193d311 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -1539,8 +1539,8 @@ object SQLConf {
       .internal()
       .doc("Prune nested fields from a logical relation's output which are unnecessary in " +
         "satisfying a query. This optimization allows columnar file format readers to avoid " +
-        "reading unnecessary nested column data. Currently Parquet is the only data source that " +
-        "implements this optimization.")
+        "reading unnecessary nested column data. Currently Parquet and ORC v1 are the " +
+        "data sources that implement this optimization.")
       .booleanConf
       .createWithDefault(false)
 
diff --git a/sql/core/benchmarks/OrcNestedSchemaPruningBenchmark-results.txt b/sql/core/benchmarks/OrcNestedSchemaPruningBenchmark-results.txt
index f738256..fdd35cd 100644
--- a/sql/core/benchmarks/OrcNestedSchemaPruningBenchmark-results.txt
+++ b/sql/core/benchmarks/OrcNestedSchemaPruningBenchmark-results.txt
@@ -6,35 +6,35 @@ Java HotSpot(TM) 64-Bit Server VM 1.8.0_202-b08 on Mac OS X 10.14.3
 Intel(R) Core(TM) i7-8750H CPU @ 2.20GHz
 Selection:                                Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
 ------------------------------------------------------------------------------------------------------------------------
-Top-level column                                    113            196          89          8.8         113.0       1.0X
-Nested column                                      1316           1639         240          0.8        1315.5       0.1X
+Top-level column                                    116            151          36          8.6         116.3       1.0X
+Nested column                                       544            604          31          1.8         544.5       0.2X
 
 Java HotSpot(TM) 64-Bit Server VM 1.8.0_202-b08 on Mac OS X 10.14.3
 Intel(R) Core(TM) i7-8750H CPU @ 2.20GHz
 Limiting:                                 Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
 ------------------------------------------------------------------------------------------------------------------------
-Top-level column                                    260            474         211          3.8         260.4       1.0X
-Nested column                                      2322           3312         701          0.4        2322.3       0.1X
+Top-level column                                    360            397          32          2.8         360.4       1.0X
+Nested column                                      3322           3503         166          0.3        3322.4       0.1X
 
 Java HotSpot(TM) 64-Bit Server VM 1.8.0_202-b08 on Mac OS X 10.14.3
 Intel(R) Core(TM) i7-8750H CPU @ 2.20GHz
 Repartitioning:                           Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
 ------------------------------------------------------------------------------------------------------------------------
-Top-level column                                    275            318          55          3.6         274.8       1.0X
-Nested column                                      2482           3263         759          0.4        2482.2       0.1X
+Top-level column                                    292            334          32          3.4         291.8       1.0X
+Nested column                                      3306           3489         123          0.3        3305.7       0.1X
 
 Java HotSpot(TM) 64-Bit Server VM 1.8.0_202-b08 on Mac OS X 10.14.3
 Intel(R) Core(TM) i7-8750H CPU @ 2.20GHz
 Repartitioning by exprs:                  Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
 ------------------------------------------------------------------------------------------------------------------------
-Top-level column                                    274            288          11          3.7         273.9       1.0X
-Nested column                                      2783           2905          86          0.4        2782.7       0.1X
+Top-level column                                    302            333          27          3.3         302.0       1.0X
+Nested column                                      2697           3347         390          0.4        2697.4       0.1X
 
 Java HotSpot(TM) 64-Bit Server VM 1.8.0_202-b08 on Mac OS X 10.14.3
 Intel(R) Core(TM) i7-8750H CPU @ 2.20GHz
 Sorting:                                  Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
 ------------------------------------------------------------------------------------------------------------------------
-Top-level column                                    382            419          23          2.6         382.4       1.0X
-Nested column                                      2974           3517         699          0.3        2974.1       0.1X
+Top-level column                                    316            440         146          3.2         315.8       1.0X
+Nested column                                      2728           2928         205          0.4        2727.9       0.1X
 
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala
index 2a76495..01f8ce7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala
@@ -38,6 +38,7 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
 import org.apache.spark.sql.execution.datasources._
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.sources._
 import org.apache.spark.sql.types._
 import org.apache.spark.util.SerializableConfiguration
@@ -164,6 +165,10 @@ class OrcFileFormat
     val enableVectorizedReader = supportBatch(sparkSession, resultSchema)
     val capacity = sqlConf.orcVectorizedReaderBatchSize
 
+    val resultSchemaString = OrcUtils.orcTypeDescriptionString(resultSchema)
+    OrcConf.MAPRED_INPUT_SCHEMA.setString(hadoopConf, resultSchemaString)
+    OrcConf.IS_SCHEMA_EVOLUTION_CASE_SENSITIVE.setBoolean(hadoopConf, sqlConf.caseSensitiveAnalysis)
+
     val broadcastedConf =
       sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf))
     val isCaseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis
@@ -187,8 +192,6 @@ class OrcFileFormat
         assert(requestedColIds.length == requiredSchema.length,
           "[BUG] requested column IDs do not match required schema")
         val taskConf = new Configuration(conf)
-        taskConf.set(OrcConf.INCLUDE_COLUMNS.getAttribute,
-          requestedColIds.filter(_ != -1).sorted.mkString(","))
 
         val fileSplit = new FileSplit(filePath, file.start, file.length, Array.empty)
         val attemptId = new TaskAttemptID(new TaskID(new JobID(), TaskType.MAP, 0), 0)
@@ -206,7 +209,7 @@ class OrcFileFormat
             Array.fill(requiredSchema.length)(-1) ++ Range(0, partitionSchema.length)
           batchReader.initialize(fileSplit, taskAttemptContext)
           batchReader.initBatch(
-            reader.getSchema,
+            TypeDescription.fromString(resultSchemaString),
             resultSchema.fields,
             requestedDataColIds,
             requestedPartitionColIds,
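
The core of the change above is that the reader now receives the pruned result schema through
OrcConf.MAPRED_INPUT_SCHEMA (plus a case-sensitivity hint) instead of a list of physical column
ids via INCLUDE_COLUMNS, so ORC's schema evolution materializes only the requested nested
fields. A minimal, standalone ORC-core sketch of that mechanism follows; it sets the read
schema directly on Reader.Options rather than through the MapReduce conf key, and the file
path and schema string are illustrative assumptions.

    // Standalone ORC-core sketch (not Spark code): read a file with a pruned read schema.
    import org.apache.hadoop.conf.Configuration
    import org.apache.hadoop.fs.Path
    import org.apache.orc.{OrcFile, TypeDescription}

    // Only name.middle out of a wider file schema.
    val prunedSchema = TypeDescription.fromString("struct<name:struct<middle:string>>")

    val reader = OrcFile.createReader(
      new Path("/tmp/contacts_orc/part-00000.orc"),   // illustrative path
      OrcFile.readerOptions(new Configuration()))

    // ORC resolves the pruned schema against the file schema and reads only the
    // matching subfields.
    val rows = reader.rows(reader.options().schema(prunedSchema))
    val batch = prunedSchema.createRowBatch()
    while (rows.nextBatch(batch)) {
      // batch.cols(0) is a StructColumnVector holding just the `middle` child.
    }
    rows.close()
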
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala
index 57d2c56..fb9f87c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala
@@ -32,6 +32,7 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.sql.{SPARK_VERSION_METADATA_KEY, SparkSession}
 import org.apache.spark.sql.catalyst.analysis.caseSensitiveResolution
 import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
+import org.apache.spark.sql.catalyst.util.quoteIdentifier
 import org.apache.spark.sql.types._
 
 object OrcUtils extends Logging {
@@ -120,24 +121,27 @@ object OrcUtils extends Logging {
         })
       } else {
         if (isCaseSensitive) {
-          Some(requiredSchema.fieldNames.map { name =>
-            orcFieldNames.indexWhere(caseSensitiveResolution(_, name))
+          Some(requiredSchema.fieldNames.zipWithIndex.map { case (name, idx) =>
+            if (orcFieldNames.indexWhere(caseSensitiveResolution(_, name)) != -1) {
+              idx
+            } else {
+              -1
+            }
           })
         } else {
           // Do case-insensitive resolution only if in case-insensitive mode
-          val caseInsensitiveOrcFieldMap =
-            orcFieldNames.zipWithIndex.groupBy(_._1.toLowerCase(Locale.ROOT))
-          Some(requiredSchema.fieldNames.map { requiredFieldName =>
+          val caseInsensitiveOrcFieldMap = orcFieldNames.groupBy(_.toLowerCase(Locale.ROOT))
+          Some(requiredSchema.fieldNames.zipWithIndex.map { case (requiredFieldName, idx) =>
             caseInsensitiveOrcFieldMap
               .get(requiredFieldName.toLowerCase(Locale.ROOT))
               .map { matchedOrcFields =>
                 if (matchedOrcFields.size > 1) {
                   // Need to fail if there is ambiguity, i.e. more than one field is matched.
-                  val matchedOrcFieldsString = matchedOrcFields.map(_._1).mkString("[", ", ", "]")
+                  val matchedOrcFieldsString = matchedOrcFields.mkString("[", ", ", "]")
                   throw new RuntimeException(s"""Found duplicate field(s) "$requiredFieldName": """
                     + s"$matchedOrcFieldsString in case-insensitive mode")
                 } else {
-                  matchedOrcFields.head._2
+                  idx
                 }
               }.getOrElse(-1)
           })
@@ -152,4 +156,21 @@ object OrcUtils extends Logging {
   def addSparkVersionMetadata(writer: Writer): Unit = {
     writer.addUserMetadata(SPARK_VERSION_METADATA_KEY, UTF_8.encode(SPARK_VERSION_SHORT))
   }
+
+  /**
+   * Given a `StructType` object, this method converts it to the corresponding string
+   * representation in ORC.
+   */
+  def orcTypeDescriptionString(dt: DataType): String = dt match {
+    case s: StructType =>
+      val fieldTypes = s.fields.map { f =>
+        s"${quoteIdentifier(f.name)}:${orcTypeDescriptionString(f.dataType)}"
+      }
+      s"struct<${fieldTypes.mkString(",")}>"
+    case a: ArrayType =>
+      s"array<${orcTypeDescriptionString(a.elementType)}>"
+    case m: MapType =>
+      s"map<${orcTypeDescriptionString(m.keyType)},${orcTypeDescriptionString(m.valueType)}>"
+    case _ => dt.catalogString
+  }
 }
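
For reference, the new orcTypeDescriptionString helper converts a Catalyst schema into ORC's
type-description syntax, with field names backquoted via quoteIdentifier. A small illustration
follows; the schema is made up, and since the helper lives in an internal package it is written
as it might be called from test code.

    import org.apache.spark.sql.execution.datasources.orc.OrcUtils
    import org.apache.spark.sql.types._

    val pruned = StructType(Seq(
      StructField("name", StructType(Seq(
        StructField("middle", StringType))))))

    // Expected to yield something like: struct<`name`:struct<`middle`:string>>,
    // which is then handed to the ORC reader as its read schema.
    val orcSchemaString = OrcUtils.orcTypeDescriptionString(pruned)
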
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaPruning.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaPruning.scala
index cc33db9..47551a5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaPruning.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaPruning.scala
@@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.planning.PhysicalOperation
 import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project}
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation}
+import org.apache.spark.sql.execution.datasources.orc.OrcFileFormat
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructField, StructType}
 
@@ -81,7 +82,8 @@ object ParquetSchemaPruning extends Rule[LogicalPlan] {
    * Checks to see if the given relation is Parquet and can be pruned.
    */
   private def canPruneRelation(fsRelation: HadoopFsRelation) =
-    fsRelation.fileFormat.isInstanceOf[ParquetFileFormat]
+    fsRelation.fileFormat.isInstanceOf[ParquetFileFormat] ||
+      fsRelation.fileFormat.isInstanceOf[OrcFileFormat]
 
   /**
    * Normalizes the names of the attribute references in the given projects and filters to reflect
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcPartitionReaderFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcPartitionReaderFactory.scala
index 4ae10a6..1da9469 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcPartitionReaderFactory.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcPartitionReaderFactory.scala
@@ -23,7 +23,7 @@ import org.apache.hadoop.fs.Path
 import org.apache.hadoop.mapreduce.{JobID, TaskAttemptID, TaskID, TaskType}
 import org.apache.hadoop.mapreduce.lib.input.FileSplit
 import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl
-import org.apache.orc.{OrcConf, OrcFile}
+import org.apache.orc.{OrcConf, OrcFile, TypeDescription}
 import org.apache.orc.mapred.OrcStruct
 import org.apache.orc.mapreduce.OrcInputFormat
 
@@ -67,6 +67,11 @@ case class OrcPartitionReaderFactory(
   override def buildReader(file: PartitionedFile): PartitionReader[InternalRow] = {
     val conf = broadcastedConf.value.value
 
+    val readDataSchema = getReadDataSchema(readSchema, partitionSchema, isCaseSensitive)
+    val readDataSchemaString = OrcUtils.orcTypeDescriptionString(readDataSchema)
+    OrcConf.MAPRED_INPUT_SCHEMA.setString(conf, readDataSchemaString)
+    OrcConf.IS_SCHEMA_EVOLUTION_CASE_SENSITIVE.setBoolean(conf, isCaseSensitive)
+
     val filePath = new Path(new URI(file.filePath))
 
     val fs = filePath.getFileSystem(conf)
@@ -74,23 +79,21 @@ case class OrcPartitionReaderFactory(
     val reader = OrcFile.createReader(filePath, readerOptions)
 
     val requestedColIdsOrEmptyFile = OrcUtils.requestedColumnIds(
-      isCaseSensitive, dataSchema, readSchema, reader, conf)
+      isCaseSensitive, dataSchema, readDataSchema, reader, conf)
 
     if (requestedColIdsOrEmptyFile.isEmpty) {
       new EmptyPartitionReader[InternalRow]
     } else {
       val requestedColIds = requestedColIdsOrEmptyFile.get
-      assert(requestedColIds.length == readSchema.length,
+      assert(requestedColIds.length == readDataSchema.length,
         "[BUG] requested column IDs do not match required schema")
+
       val taskConf = new Configuration(conf)
-      taskConf.set(OrcConf.INCLUDE_COLUMNS.getAttribute,
-        requestedColIds.filter(_ != -1).sorted.mkString(","))
 
       val fileSplit = new FileSplit(filePath, file.start, file.length, Array.empty)
       val attemptId = new TaskAttemptID(new TaskID(new JobID(), TaskType.MAP, 0), 0)
       val taskAttemptContext = new TaskAttemptContextImpl(taskConf, attemptId)
 
-      val readDataSchema = getReadDataSchema(readSchema, partitionSchema, isCaseSensitive)
       val orcRecordReader = new OrcInputFormat[OrcStruct]
         .createRecordReader(fileSplit, taskAttemptContext)
       val deserializer = new OrcDeserializer(dataSchema, readDataSchema, requestedColIds)
@@ -110,6 +113,10 @@ case class OrcPartitionReaderFactory(
   override def buildColumnarReader(file: PartitionedFile): PartitionReader[ColumnarBatch] = {
     val conf = broadcastedConf.value.value
 
+    val readSchemaString = OrcUtils.orcTypeDescriptionString(readSchema)
+    OrcConf.MAPRED_INPUT_SCHEMA.setString(conf, readSchemaString)
+    OrcConf.IS_SCHEMA_EVOLUTION_CASE_SENSITIVE.setBoolean(conf, isCaseSensitive)
+
     val filePath = new Path(new URI(file.filePath))
 
     val fs = filePath.getFileSystem(conf)
@@ -126,8 +133,6 @@ case class OrcPartitionReaderFactory(
       assert(requestedColIds.length == readSchema.length,
         "[BUG] requested column IDs do not match required schema")
       val taskConf = new Configuration(conf)
-      taskConf.set(OrcConf.INCLUDE_COLUMNS.getAttribute,
-        requestedColIds.filter(_ != -1).sorted.mkString(","))
 
       val fileSplit = new FileSplit(filePath, file.start, file.length, Array.empty)
       val attemptId = new TaskAttemptID(new TaskID(new JobID(), TaskType.MAP, 0), 0)
@@ -142,7 +147,7 @@ case class OrcPartitionReaderFactory(
       }
 
       batchReader.initBatch(
-        reader.getSchema,
+        TypeDescription.fromString(readSchemaString),
         readSchema.fields,
         requestedColIds,
         requestedPartitionColIds,
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala
similarity index 85%
copy from sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaPruningSuite.scala
copy to sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala
index 4d15f38..d328ef4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaPruningSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala
@@ -15,7 +15,7 @@
  * limitations under the License.
  */
 
-package org.apache.spark.sql.execution.datasources.parquet
+package org.apache.spark.sql.execution.datasources
 
 import java.io.File
 
@@ -30,11 +30,11 @@ import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types.StructType
 
-class ParquetSchemaPruningSuite
-    extends QueryTest
-    with ParquetTest
-    with SchemaPruningTest
-    with SharedSQLContext {
+abstract class SchemaPruningSuite
+  extends QueryTest
+  with FileBasedDataSourceTest
+  with SchemaPruningTest
+  with SharedSQLContext {
   case class FullName(first: String, middle: String, last: String)
   case class Company(name: String, address: String)
   case class Employer(id: Int, company: Company)
@@ -54,7 +54,7 @@ class ParquetSchemaPruningSuite
   val employer = Employer(0, Company("abc", "123 Business Street"))
   val employerWithNullCompany = Employer(1, null)
 
-  private val contacts =
+  val contacts =
     Contact(0, janeDoe, "123 Main Street", 1, friends = Array(susanSmith),
       relatives = Map("brother" -> johnDoe), employer = employer) ::
     Contact(1, johnDoe, "321 Wall Street", 3, relatives = Map("sister" -> janeDoe),
@@ -79,10 +79,10 @@ class ParquetSchemaPruningSuite
 
   case class BriefContactWithDataPartitionColumn(id: Int, name: Name, address: String, p: Int)
 
-  private val contactsWithDataPartitionColumn =
+  val contactsWithDataPartitionColumn =
     contacts.map { case Contact(id, name, address, pets, friends, relatives, employer) =>
       ContactWithDataPartitionColumn(id, name, address, pets, friends, relatives, employer, 1) }
-  private val briefContactsWithDataPartitionColumn =
+  val briefContactsWithDataPartitionColumn =
     briefContacts.map { case BriefContact(id, name, address) =>
       BriefContactWithDataPartitionColumn(id, name, address, 2) }
 
@@ -253,25 +253,25 @@ class ParquetSchemaPruningSuite
     checkAnswer(query, Row(1) :: Nil)
   }
 
-  private def testSchemaPruning(testName: String)(testThunk: => Unit) {
+  protected def testSchemaPruning(testName: String)(testThunk: => Unit) {
     test(s"Spark vectorized reader - without partition data column - $testName") {
-      withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "true") {
+      withSQLConf(vectorizedReaderEnabledKey -> "true") {
         withContacts(testThunk)
       }
     }
     test(s"Spark vectorized reader - with partition data column - $testName") {
-      withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "true") {
+      withSQLConf(vectorizedReaderEnabledKey -> "true") {
         withContactsWithDataPartitionColumn(testThunk)
       }
     }
 
-    test(s"Parquet-mr reader - without partition data column - $testName") {
-      withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") {
+    test(s"Non-vectorized reader - without partition data column - $testName") {
+      withSQLConf(vectorizedReaderEnabledKey -> "false") {
         withContacts(testThunk)
       }
     }
-    test(s"Parquet-mr reader - with partition data column - $testName") {
-      withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") {
+    test(s"Non-vectorized reader - with partition data column - $testName") {
+      withSQLConf(vectorizedReaderEnabledKey -> "false") {
         withContactsWithDataPartitionColumn(testThunk)
       }
     }
@@ -281,10 +281,18 @@ class ParquetSchemaPruningSuite
     withTempPath { dir =>
       val path = dir.getCanonicalPath
 
-      makeParquetFile(contacts, new File(path + "/contacts/p=1"))
-      makeParquetFile(briefContacts, new File(path + "/contacts/p=2"))
+      makeDataSourceFile(contacts, new File(path + "/contacts/p=1"))
+      makeDataSourceFile(briefContacts, new File(path + "/contacts/p=2"))
 
-      spark.read.parquet(path + "/contacts").createOrReplaceTempView("contacts")
+      // Provide a user-specified schema, since the inferred schema can differ across
+      // data sources.
+      val schema = "`id` INT,`name` STRUCT<`first`: STRING, `middle`: STRING, `last`: STRING>, " +
+        "`address` STRING,`pets` INT,`friends` ARRAY<STRUCT<`first`: STRING, `middle`: STRING, " +
+        "`last`: STRING>>,`relatives` MAP<STRING, STRUCT<`first`: STRING, `middle`: STRING, " +
+        "`last`: STRING>>,`employer` STRUCT<`id`: INT, `company`: STRUCT<`name`: STRING, " +
+        "`address`: STRING>>,`p` INT"
+      spark.read.format(dataSourceName).schema(schema).load(path + "/contacts")
+        .createOrReplaceTempView("contacts")
 
       testThunk
     }
@@ -294,10 +302,18 @@ class ParquetSchemaPruningSuite
     withTempPath { dir =>
       val path = dir.getCanonicalPath
 
-      makeParquetFile(contactsWithDataPartitionColumn, new File(path + "/contacts/p=1"))
-      makeParquetFile(briefContactsWithDataPartitionColumn, new File(path + "/contacts/p=2"))
+      makeDataSourceFile(contactsWithDataPartitionColumn, new File(path + "/contacts/p=1"))
+      makeDataSourceFile(briefContactsWithDataPartitionColumn, new File(path + "/contacts/p=2"))
 
-      spark.read.parquet(path + "/contacts").createOrReplaceTempView("contacts")
+      // Provide a user-specified schema, since the inferred schema can differ across
+      // data sources.
+      val schema = "`id` INT,`name` STRUCT<`first`: STRING, `middle`: STRING, `last`: STRING>, " +
+        "`address` STRING,`pets` INT,`friends` ARRAY<STRUCT<`first`: STRING, `middle`: STRING, " +
+        "`last`: STRING>>,`relatives` MAP<STRING, STRUCT<`first`: STRING, `middle`: STRING, " +
+        "`last`: STRING>>,`employer` STRUCT<`id`: INT, `company`: STRUCT<`name`: STRING, " +
+        "`address`: STRING>>,`p` INT"
+      spark.read.format(dataSourceName).schema(schema).load(path + "/contacts")
+        .createOrReplaceTempView("contacts")
 
       testThunk
     }
@@ -366,9 +382,9 @@ class ParquetSchemaPruningSuite
     }
   }
 
-  // Tests given test function with Spark vectorized reader and Parquet-mr reader.
+  // Tests given test function with Spark vectorized reader and non-vectorized reader.
   private def withMixedCaseData(testThunk: => Unit) {
-    withParquetTable(mixedCaseData, "mixedcase") {
+    withDataSourceTable(mixedCaseData, "mixedcase") {
       testThunk
     }
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSchemaPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSchemaPruningSuite.scala
new file mode 100644
index 0000000..5dade6f
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSchemaPruningSuite.scala
@@ -0,0 +1,34 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.orc
+
+import org.apache.spark.SparkConf
+import org.apache.spark.sql.execution.datasources.SchemaPruningSuite
+import org.apache.spark.sql.internal.SQLConf
+
+class OrcSchemaPruningSuite extends SchemaPruningSuite {
+  override protected val dataSourceName: String = "orc"
+  override protected val vectorizedReaderEnabledKey: String =
+    SQLConf.ORC_VECTORIZED_READER_ENABLED.key
+
+  override protected def sparkConf: SparkConf =
+    super
+      .sparkConf
+      .set(SQLConf.USE_V1_SOURCE_READER_LIST, "orc")
+      .set(SQLConf.USE_V1_SOURCE_WRITER_LIST, "orc")
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaPruningSuite.scala
index 4d15f38..3d97d64 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaPruningSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaPruningSuite.scala
@@ -17,390 +17,11 @@
 
 package org.apache.spark.sql.execution.datasources.parquet
 
-import java.io.File
-
-import org.scalactic.Equality
-
-import org.apache.spark.sql.{DataFrame, QueryTest, Row}
-import org.apache.spark.sql.catalyst.SchemaPruningTest
-import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
-import org.apache.spark.sql.execution.FileSourceScanExec
-import org.apache.spark.sql.functions._
+import org.apache.spark.sql.execution.datasources.SchemaPruningSuite
 import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.test.SharedSQLContext
-import org.apache.spark.sql.types.StructType
-
-class ParquetSchemaPruningSuite
-    extends QueryTest
-    with ParquetTest
-    with SchemaPruningTest
-    with SharedSQLContext {
-  case class FullName(first: String, middle: String, last: String)
-  case class Company(name: String, address: String)
-  case class Employer(id: Int, company: Company)
-  case class Contact(
-    id: Int,
-    name: FullName,
-    address: String,
-    pets: Int,
-    friends: Array[FullName] = Array.empty,
-    relatives: Map[String, FullName] = Map.empty,
-    employer: Employer = null)
-
-  val janeDoe = FullName("Jane", "X.", "Doe")
-  val johnDoe = FullName("John", "Y.", "Doe")
-  val susanSmith = FullName("Susan", "Z.", "Smith")
-
-  val employer = Employer(0, Company("abc", "123 Business Street"))
-  val employerWithNullCompany = Employer(1, null)
-
-  private val contacts =
-    Contact(0, janeDoe, "123 Main Street", 1, friends = Array(susanSmith),
-      relatives = Map("brother" -> johnDoe), employer = employer) ::
-    Contact(1, johnDoe, "321 Wall Street", 3, relatives = Map("sister" -> janeDoe),
-      employer = employerWithNullCompany) :: Nil
-
-  case class Name(first: String, last: String)
-  case class BriefContact(id: Int, name: Name, address: String)
-
-  private val briefContacts =
-    BriefContact(2, Name("Janet", "Jones"), "567 Maple Drive") ::
-    BriefContact(3, Name("Jim", "Jones"), "6242 Ash Street") :: Nil
-
-  case class ContactWithDataPartitionColumn(
-    id: Int,
-    name: FullName,
-    address: String,
-    pets: Int,
-    friends: Array[FullName] = Array(),
-    relatives: Map[String, FullName] = Map(),
-    employer: Employer = null,
-    p: Int)
-
-  case class BriefContactWithDataPartitionColumn(id: Int, name: Name, address: String, p: Int)
-
-  private val contactsWithDataPartitionColumn =
-    contacts.map { case Contact(id, name, address, pets, friends, relatives, employer) =>
-      ContactWithDataPartitionColumn(id, name, address, pets, friends, relatives, employer, 1) }
-  private val briefContactsWithDataPartitionColumn =
-    briefContacts.map { case BriefContact(id, name, address) =>
-      BriefContactWithDataPartitionColumn(id, name, address, 2) }
-
-  testSchemaPruning("select a single complex field") {
-    val query = sql("select name.middle from contacts")
-    checkScan(query, "struct<name:struct<middle:string>>")
-    checkAnswer(query.orderBy("id"), Row("X.") :: Row("Y.") :: Row(null) :: Row(null) :: Nil)
-  }
-
-  testSchemaPruning("select a single complex field and its parent struct") {
-    val query = sql("select name.middle, name from contacts")
-    checkScan(query, "struct<name:struct<first:string,middle:string,last:string>>")
-    checkAnswer(query.orderBy("id"),
-      Row("X.", Row("Jane", "X.", "Doe")) ::
-      Row("Y.", Row("John", "Y.", "Doe")) ::
-      Row(null, Row("Janet", null, "Jones")) ::
-      Row(null, Row("Jim", null, "Jones")) ::
-      Nil)
-  }
-
-  testSchemaPruning("select a single complex field array and its parent struct array") {
-    val query = sql("select friends.middle, friends from contacts where p=1")
-    checkScan(query,
-      "struct<friends:array<struct<first:string,middle:string,last:string>>>")
-    checkAnswer(query.orderBy("id"),
-      Row(Array("Z."), Array(Row("Susan", "Z.", "Smith"))) ::
-      Row(Array.empty[String], Array.empty[Row]) ::
-      Nil)
-  }
-
-  testSchemaPruning("select a single complex field from a map entry and its parent map entry") {
-    val query =
-      sql("select relatives[\"brother\"].middle, relatives[\"brother\"] from contacts where p=1")
-    checkScan(query,
-      "struct<relatives:map<string,struct<first:string,middle:string,last:string>>>")
-    checkAnswer(query.orderBy("id"),
-      Row("Y.", Row("John", "Y.", "Doe")) ::
-      Row(null, null) ::
-      Nil)
-  }
-
-  testSchemaPruning("select a single complex field and the partition column") {
-    val query = sql("select name.middle, p from contacts")
-    checkScan(query, "struct<name:struct<middle:string>>")
-    checkAnswer(query.orderBy("id"),
-      Row("X.", 1) :: Row("Y.", 1) :: Row(null, 2) :: Row(null, 2) :: Nil)
-  }
-
-  ignore("partial schema intersection - select missing subfield") {
-    val query = sql("select name.middle, address from contacts where p=2")
-    checkScan(query, "struct<name:struct<middle:string>,address:string>")
-    checkAnswer(query.orderBy("id"),
-      Row(null, "567 Maple Drive") ::
-      Row(null, "6242 Ash Street") :: Nil)
-  }
-
-  testSchemaPruning("no unnecessary schema pruning") {
-    val query =
-      sql("select id, name.last, name.middle, name.first, relatives[''].last, " +
-        "relatives[''].middle, relatives[''].first, friends[0].last, friends[0].middle, " +
-        "friends[0].first, pets, address from contacts where p=2")
-    // We've selected every field in the schema. Therefore, no schema pruning should be performed.
-    // We check this by asserting that the scanned schema of the query is identical to the schema
-    // of the contacts relation, even though the fields are selected in different orders.
-    checkScan(query,
-      "struct<id:int,name:struct<first:string,middle:string,last:string>,address:string,pets:int," +
-      "friends:array<struct<first:string,middle:string,last:string>>," +
-      "relatives:map<string,struct<first:string,middle:string,last:string>>>")
-    checkAnswer(query.orderBy("id"),
-      Row(2, "Jones", null, "Janet", null, null, null, null, null, null, null, "567 Maple Drive") ::
-      Row(3, "Jones", null, "Jim", null, null, null, null, null, null, null, "6242 Ash Street") ::
-      Nil)
-  }
-
-  testSchemaPruning("empty schema intersection") {
-    val query = sql("select name.middle from contacts where p=2")
-    checkScan(query, "struct<name:struct<middle:string>>")
-    checkAnswer(query.orderBy("id"),
-      Row(null) :: Row(null) :: Nil)
-  }
-
-  testSchemaPruning("select a single complex field and in where clause") {
-    val query1 = sql("select name.first from contacts where name.first = 'Jane'")
-    checkScan(query1, "struct<name:struct<first:string>>")
-    checkAnswer(query1, Row("Jane") :: Nil)
-
-    val query2 = sql("select name.first, name.last from contacts where name.first = 'Jane'")
-    checkScan(query2, "struct<name:struct<first:string,last:string>>")
-    checkAnswer(query2, Row("Jane", "Doe") :: Nil)
-
-    val query3 = sql("select name.first from contacts " +
-      "where employer.company.name = 'abc' and p = 1")
-    checkScan(query3, "struct<name:struct<first:string>," +
-      "employer:struct<company:struct<name:string>>>")
-    checkAnswer(query3, Row("Jane") :: Nil)
-
-    val query4 = sql("select name.first, employer.company.name from contacts " +
-      "where employer.company is not null and p = 1")
-    checkScan(query4, "struct<name:struct<first:string>," +
-      "employer:struct<company:struct<name:string>>>")
-    checkAnswer(query4, Row("Jane", "abc") :: Nil)
-  }
-
-  testSchemaPruning("select nullable complex field and having is not null predicate") {
-    val query = sql("select employer.company from contacts " +
-      "where employer is not null and p = 1")
-    checkScan(query, "struct<employer:struct<company:struct<name:string,address:string>>>")
-    checkAnswer(query, Row(Row("abc", "123 Business Street")) :: Row(null) :: Nil)
-  }
-
-  testSchemaPruning("select a single complex field and is null expression in project") {
-    val query = sql("select name.first, address is not null from contacts")
-    checkScan(query, "struct<name:struct<first:string>,address:string>")
-    checkAnswer(query.orderBy("id"),
-      Row("Jane", true) :: Row("John", true) :: Row("Janet", true) :: Row("Jim", true) :: Nil)
-  }
-
-  testSchemaPruning("select a single complex field array and in clause") {
-    val query = sql("select friends.middle from contacts where friends.first[0] = 'Susan'")
-    checkScan(query,
-      "struct<friends:array<struct<first:string,middle:string>>>")
-    checkAnswer(query.orderBy("id"),
-      Row(Array("Z.")) :: Nil)
-  }
-
-  testSchemaPruning("select a single complex field from a map entry and in clause") {
-    val query =
-      sql("select relatives[\"brother\"].middle from contacts " +
-        "where relatives[\"brother\"].first = 'John'")
-    checkScan(query,
-      "struct<relatives:map<string,struct<first:string,middle:string>>>")
-    checkAnswer(query.orderBy("id"),
-      Row("Y.") :: Nil)
-  }
-
-  testSchemaPruning("select one complex field and having is null predicate on another " +
-      "complex field") {
-    val query = sql("select * from contacts")
-      .where("name.middle is not null")
-      .select(
-        "id",
-        "name.first",
-        "name.middle",
-        "name.last"
-      )
-      .where("last = 'Jones'")
-      .select(count("id")).toDF()
-    checkScan(query,
-      "struct<id:int,name:struct<middle:string,last:string>>")
-    checkAnswer(query, Row(0) :: Nil)
-  }
-
-  testSchemaPruning("select one deep nested complex field and having is null predicate on " +
-      "another deep nested complex field") {
-    val query = sql("select * from contacts")
-      .where("employer.company.address is not null")
-      .selectExpr(
-        "id",
-        "name.first",
-        "name.middle",
-        "name.last",
-        "employer.id as employer_id"
-      )
-      .where("employer_id = 0")
-      .select(count("id")).toDF()
-    checkScan(query,
-      "struct<id:int,employer:struct<id:int,company:struct<address:string>>>")
-    checkAnswer(query, Row(1) :: Nil)
-  }
-
-  private def testSchemaPruning(testName: String)(testThunk: => Unit) {
-    test(s"Spark vectorized reader - without partition data column - $testName") {
-      withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "true") {
-        withContacts(testThunk)
-      }
-    }
-    test(s"Spark vectorized reader - with partition data column - $testName") {
-      withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "true") {
-        withContactsWithDataPartitionColumn(testThunk)
-      }
-    }
-
-    test(s"Parquet-mr reader - without partition data column - $testName") {
-      withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") {
-        withContacts(testThunk)
-      }
-    }
-    test(s"Parquet-mr reader - with partition data column - $testName") {
-      withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") {
-        withContactsWithDataPartitionColumn(testThunk)
-      }
-    }
-  }
-
-  private def withContacts(testThunk: => Unit) {
-    withTempPath { dir =>
-      val path = dir.getCanonicalPath
-
-      makeParquetFile(contacts, new File(path + "/contacts/p=1"))
-      makeParquetFile(briefContacts, new File(path + "/contacts/p=2"))
-
-      spark.read.parquet(path + "/contacts").createOrReplaceTempView("contacts")
-
-      testThunk
-    }
-  }
-
-  private def withContactsWithDataPartitionColumn(testThunk: => Unit) {
-    withTempPath { dir =>
-      val path = dir.getCanonicalPath
-
-      makeParquetFile(contactsWithDataPartitionColumn, new File(path + "/contacts/p=1"))
-      makeParquetFile(briefContactsWithDataPartitionColumn, new File(path + "/contacts/p=2"))
-
-      spark.read.parquet(path + "/contacts").createOrReplaceTempView("contacts")
-
-      testThunk
-    }
-  }
-
-  case class MixedCaseColumn(a: String, B: Int)
-  case class MixedCase(id: Int, CoL1: String, coL2: MixedCaseColumn)
-
-  private val mixedCaseData =
-    MixedCase(0, "r0c1", MixedCaseColumn("abc", 1)) ::
-    MixedCase(1, "r1c1", MixedCaseColumn("123", 2)) ::
-    Nil
-
-  testExactCaseQueryPruning("select with exact column names") {
-    val query = sql("select CoL1, coL2.B from mixedcase")
-    checkScan(query, "struct<CoL1:string,coL2:struct<B:int>>")
-    checkAnswer(query.orderBy("id"),
-      Row("r0c1", 1) ::
-      Row("r1c1", 2) ::
-      Nil)
-  }
-
-  testMixedCaseQueryPruning("select with lowercase column names") {
-    val query = sql("select col1, col2.b from mixedcase")
-    checkScan(query, "struct<CoL1:string,coL2:struct<B:int>>")
-    checkAnswer(query.orderBy("id"),
-      Row("r0c1", 1) ::
-      Row("r1c1", 2) ::
-      Nil)
-  }
-
-  testMixedCaseQueryPruning("select with different-case column names") {
-    val query = sql("select cOL1, cOl2.b from mixedcase")
-    checkScan(query, "struct<CoL1:string,coL2:struct<B:int>>")
-    checkAnswer(query.orderBy("id"),
-      Row("r0c1", 1) ::
-      Row("r1c1", 2) ::
-      Nil)
-  }
-
-  testMixedCaseQueryPruning("filter with different-case column names") {
-    val query = sql("select id from mixedcase where Col2.b = 2")
-    checkScan(query, "struct<id:int,coL2:struct<B:int>>")
-    checkAnswer(query.orderBy("id"), Row(1) :: Nil)
-  }
-
-  // Tests schema pruning for a query whose column and field names are exactly the same as the table
-  // schema's column and field names. N.B. this implies that `testThunk` should pass using either a
-  // case-sensitive or case-insensitive query parser
-  private def testExactCaseQueryPruning(testName: String)(testThunk: => Unit) {
-    test(s"Case-sensitive parser - mixed-case schema - $testName") {
-      withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") {
-        withMixedCaseData(testThunk)
-      }
-    }
-    testMixedCaseQueryPruning(testName)(testThunk)
-  }
-
-  // Tests schema pruning for a query whose column and field names may differ in case from the table
-  // schema's column and field names
-  private def testMixedCaseQueryPruning(testName: String)(testThunk: => Unit) {
-    test(s"Case-insensitive parser - mixed-case schema - $testName") {
-      withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") {
-        withMixedCaseData(testThunk)
-      }
-    }
-  }
-
-  // Tests given test function with Spark vectorized reader and Parquet-mr reader.
-  private def withMixedCaseData(testThunk: => Unit) {
-    withParquetTable(mixedCaseData, "mixedcase") {
-      testThunk
-    }
-  }
-
-  private val schemaEquality = new Equality[StructType] {
-    override def areEqual(a: StructType, b: Any): Boolean =
-      b match {
-        case otherType: StructType => a.sameType(otherType)
-        case _ => false
-      }
-  }
-
-  protected def checkScan(df: DataFrame, expectedSchemaCatalogStrings: String*): Unit = {
-    checkScanSchemata(df, expectedSchemaCatalogStrings: _*)
-    // We check here that we can execute the query without throwing an exception. The results
-    // themselves are irrelevant, and should be checked elsewhere as needed
-    df.collect()
-  }
 
-  private def checkScanSchemata(df: DataFrame, expectedSchemaCatalogStrings: String*): Unit = {
-    val fileSourceScanSchemata =
-      df.queryExecution.executedPlan.collect {
-        case scan: FileSourceScanExec => scan.requiredSchema
-      }
-    assert(fileSourceScanSchemata.size === expectedSchemaCatalogStrings.size,
-      s"Found ${fileSourceScanSchemata.size} file sources in dataframe, " +
-        s"but expected $expectedSchemaCatalogStrings")
-    fileSourceScanSchemata.zip(expectedSchemaCatalogStrings).foreach {
-      case (scanSchema, expectedScanSchemaCatalogString) =>
-        val expectedScanSchema = CatalystSqlParser.parseDataType(expectedScanSchemaCatalogString)
-        implicit val equality = schemaEquality
-        assert(scanSchema === expectedScanSchema)
-    }
-  }
+class ParquetSchemaPruningSuite extends SchemaPruningSuite {
+  override protected val dataSourceName: String = "parquet"
+  override protected val vectorizedReaderEnabledKey: String =
+    SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org