You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by we...@apache.org on 2021/04/21 07:19:46 UTC

[spark] branch branch-3.0 updated: [SPARK-35096][SQL] SchemaPruning should adhere spark.sql.caseSensitive config

This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch branch-3.0
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.0 by this push:
     new aef17b7  [SPARK-35096][SQL] SchemaPruning should adhere spark.sql.caseSensitive config
aef17b7 is described below

commit aef17b740f48fe559520e1c958f7e22a4a00f698
Author: sandeep.katta <sa...@gmail.com>
AuthorDate: Wed Apr 21 15:16:17 2021 +0800

    [SPARK-35096][SQL] SchemaPruning should adhere spark.sql.caseSensitive config
    
    ### What changes were proposed in this pull request?
    
    As part of SPARK-26837, pruning of nested fields from object serializers was added, but that change did not handle Spark's case-insensitive column resolution.
    
    In this PR, the column names to be pruned are resolved according to the `spark.sql.caseSensitive` config.
    **Exception Before Fix**
    
    ```
    Caused by: java.lang.ArrayIndexOutOfBoundsException: 0
      at org.apache.spark.sql.types.StructType.apply(StructType.scala:414)
      at org.apache.spark.sql.catalyst.optimizer.ObjectSerializerPruning$$anonfun$apply$4.$anonfun$applyOrElse$3(objects.scala:216)
      at scala.collection.TraversableLike.$anonfun$map$1(TraversableLike.scala:238)
      at scala.collection.immutable.List.foreach(List.scala:392)
      at scala.collection.TraversableLike.map(TraversableLike.scala:238)
      at scala.collection.TraversableLike.map$(TraversableLike.scala:231)
      at scala.collection.immutable.List.map(List.scala:298)
      at org.apache.spark.sql.catalyst.optimizer.ObjectSerializerPruning$$anonfun$apply$4.applyOrElse(objects.scala:215)
      at org.apache.spark.sql.catalyst.optimizer.ObjectSerializerPruning$$anonfun$apply$4.applyOrElse(objects.scala:203)
      at org.apache.spark.sql.catalyst.trees.TreeNode.$anonfun$transformDown$1(TreeNode.scala:309)
      at org.apache.spark.sql.catalyst.trees.CurrentOrigin$.withOrigin(TreeNode.scala:72)
      at org.apache.spark.sql.catalyst.trees.TreeNode.transformDown(TreeNode.scala:309)
      at
    ```
    
    ### Why are the changes needed?
    After upgrading to Spark 3, the `foreachBatch` API throws `java.lang.ArrayIndexOutOfBoundsException`. This PR fixes that issue.
    
    ### Does this PR introduce _any_ user-facing change?
    No. In fact, it fixes a regression.
    
    ### How was this patch tested?
    Added unit tests and also verified the fix manually.
    
    Closes #32194 from sandeep-katta/SPARK-35096.
    
    Authored-by: sandeep.katta <sa...@gmail.com>
    Signed-off-by: Wenchen Fan <we...@databricks.com>
---
 .../sql/catalyst/expressions/SchemaPruning.scala   | 15 +++++---
 .../catalyst/expressions/SchemaPruningSuite.scala  | 43 +++++++++++++++++++++-
 2 files changed, 52 insertions(+), 6 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SchemaPruning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SchemaPruning.scala
index 6213267..4ee6488 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SchemaPruning.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SchemaPruning.scala
@@ -17,9 +17,10 @@
 
 package org.apache.spark.sql.catalyst.expressions
 
+import org.apache.spark.sql.catalyst.SQLConfHelper
 import org.apache.spark.sql.types._
 
-object SchemaPruning {
+object SchemaPruning extends SQLConfHelper {
   /**
    * Filters the schema by the requested fields. For example, if the schema is struct<a:int, b:int>,
    * and given requested field are "a", the field "b" is pruned in the returned schema.
@@ -28,6 +29,7 @@ object SchemaPruning {
   def pruneDataSchema(
       dataSchema: StructType,
       requestedRootFields: Seq[RootField]): StructType = {
+    val resolver = conf.resolver
     // Merge the requested root fields into a single schema. Note the ordering of the fields
     // in the resulting schema may differ from their ordering in the logical relation's
     // original schema
@@ -36,7 +38,7 @@ object SchemaPruning {
       .reduceLeft(_ merge _)
     val dataSchemaFieldNames = dataSchema.fieldNames.toSet
     val mergedDataSchema =
-      StructType(mergedSchema.filter(f => dataSchemaFieldNames.contains(f.name)))
+      StructType(mergedSchema.filter(f => dataSchemaFieldNames.exists(resolver(_, f.name))))
     // Sort the fields of mergedDataSchema according to their order in dataSchema,
     // recursively. This makes mergedDataSchema a pruned schema of dataSchema
     sortLeftFieldsByRight(mergedDataSchema, dataSchema).asInstanceOf[StructType]
@@ -61,12 +63,15 @@ object SchemaPruning {
           sortLeftFieldsByRight(leftValueType, rightValueType),
           containsNull)
       case (leftStruct: StructType, rightStruct: StructType) =>
-        val filteredRightFieldNames = rightStruct.fieldNames.filter(leftStruct.fieldNames.contains)
+        val resolver = conf.resolver
+        val filteredRightFieldNames = rightStruct.fieldNames
+          .filter(name => leftStruct.fieldNames.exists(resolver(_, name)))
         val sortedLeftFields = filteredRightFieldNames.map { fieldName =>
-          val leftFieldType = leftStruct(fieldName).dataType
+          val resolvedLeftStruct = leftStruct.find(p => resolver(p.name, fieldName)).get
+          val leftFieldType = resolvedLeftStruct.dataType
           val rightFieldType = rightStruct(fieldName).dataType
           val sortedLeftFieldType = sortLeftFieldsByRight(leftFieldType, rightFieldType)
-          StructField(fieldName, sortedLeftFieldType, nullable = leftStruct(fieldName).nullable)
+          StructField(fieldName, sortedLeftFieldType, nullable = resolvedLeftStruct.nullable)
         }
         StructType(sortedLeftFields)
       case _ => left
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SchemaPruningSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SchemaPruningSuite.scala
index c04f59e..7895f4d 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SchemaPruningSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SchemaPruningSuite.scala
@@ -18,9 +18,20 @@
 package org.apache.spark.sql.catalyst.expressions
 
 import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.expressions.SchemaPruning.RootField
+import org.apache.spark.sql.catalyst.plans.SQLHelper
+import org.apache.spark.sql.internal.SQLConf.CASE_SENSITIVE
 import org.apache.spark.sql.types._
 
-class SchemaPruningSuite extends SparkFunSuite {
+class SchemaPruningSuite extends SparkFunSuite with SQLHelper {
+
+  def getRootFields(requestedFields: StructField*): Seq[RootField] = {
+    requestedFields.map { f =>
+      // `derivedFromAtt` doesn't affect the result of pruned schema.
+      SchemaPruning.RootField(field = f, derivedFromAtt = true)
+    }
+  }
+
   test("prune schema by the requested fields") {
     def testPrunedSchema(
         schema: StructType,
@@ -59,4 +70,34 @@ class SchemaPruningSuite extends SparkFunSuite {
       StructType.fromDDL("e int, f string")))
     testPrunedSchema(complexStruct, StructField("c", IntegerType), selectFieldInMap)
   }
+
+  test("SPARK-35096: test case insensitivity of pruned schema") {
+    Seq(true, false).foreach(isCaseSensitive => {
+      withSQLConf(CASE_SENSITIVE.key -> isCaseSensitive.toString) {
+        if (isCaseSensitive) {
+          // Schema is case-sensitive
+          val requestedFields = getRootFields(StructField("id", IntegerType))
+          val prunedSchema = SchemaPruning.pruneDataSchema(
+            StructType.fromDDL("ID int, name String"), requestedFields)
+          assert(prunedSchema == StructType(Seq.empty))
+          // Root fields are case-sensitive
+          val rootFieldsSchema = SchemaPruning.pruneDataSchema(
+            StructType.fromDDL("id int, name String"),
+            getRootFields(StructField("ID", IntegerType)))
+          assert(rootFieldsSchema == StructType(StructType(Seq.empty)))
+        } else {
+          // Schema is case-insensitive
+          val prunedSchema = SchemaPruning.pruneDataSchema(
+            StructType.fromDDL("ID int, name String"),
+            getRootFields(StructField("id", IntegerType)))
+          assert(prunedSchema == StructType(StructField("ID", IntegerType) :: Nil))
+          // Root fields are case-insensitive
+          val rootFieldsSchema = SchemaPruning.pruneDataSchema(
+            StructType.fromDDL("id int, name String"),
+            getRootFields(StructField("ID", IntegerType)))
+          assert(rootFieldsSchema == StructType(StructField("id", IntegerType) :: Nil))
+        }
+      }
+    })
+  }
 }

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org