You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by we...@apache.org on 2021/02/22 13:12:56 UTC

[spark] branch branch-3.1 updated: [SPARK-34473][SQL] Avoid NPE in DataFrameReader.schema(StructType)

This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch branch-3.1
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.1 by this push:
     new 153a490  [SPARK-34473][SQL] Avoid NPE in DataFrameReader.schema(StructType)
153a490 is described below

commit 153a490b7c599f5a315e4521c7689cb7004f809b
Author: Wenchen Fan <we...@databricks.com>
AuthorDate: Mon Feb 22 21:11:21 2021 +0800

    [SPARK-34473][SQL] Avoid NPE in DataFrameReader.schema(StructType)
    
    ### What changes were proposed in this pull request?
    
    This fixes a regression in `DataFrameReader.schema(StructType)` to avoid an NPE if the given `StructType` is null. Note that passing null to Spark public APIs leads to undefined behavior. There is no documentation describing the null behavior, and it's just an accident that `DataFrameReader.schema(StructType)` worked before. So I think this is not a 3.1 blocker.
    
    ### Why are the changes needed?
    
    It fixes a 3.1 regression.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes. Now `spark.read.schema(null: StructType)` is a no-op as before, whereas in the current branch-3.1 it throws an NPE.
    
    ### How was this patch tested?
    
    This is undefined behavior and the fix is very obvious, so I didn't add a test. We should add tests when we clearly define and fix the null behavior for all public APIs.
    
    Closes #31593 from cloud-fan/minor.
    
    Authored-by: Wenchen Fan <we...@databricks.com>
    Signed-off-by: Wenchen Fan <we...@databricks.com>
    (cherry picked from commit 02c784ca686fc675b63ce37f03215bc6c2fec869)
    Signed-off-by: Wenchen Fan <we...@databricks.com>
---
 .../src/main/scala/org/apache/spark/sql/DataFrameReader.scala | 11 +++++------
 .../org/apache/spark/sql/streaming/DataStreamReader.scala     | 11 +++++------
 2 files changed, 10 insertions(+), 12 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
index b94c42a..e4da076 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
@@ -73,8 +73,10 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
    * @since 1.4.0
    */
   def schema(schema: StructType): DataFrameReader = {
-    val replaced = CharVarcharUtils.failIfHasCharVarchar(schema).asInstanceOf[StructType]
-    this.userSpecifiedSchema = Option(replaced)
+    if (schema != null) {
+      val replaced = CharVarcharUtils.failIfHasCharVarchar(schema).asInstanceOf[StructType]
+      this.userSpecifiedSchema = Option(replaced)
+    }
     this
   }
 
@@ -90,10 +92,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
    * @since 2.3.0
    */
   def schema(schemaString: String): DataFrameReader = {
-    val rawSchema = StructType.fromDDL(schemaString)
-    val schema = CharVarcharUtils.failIfHasCharVarchar(rawSchema).asInstanceOf[StructType]
-    this.userSpecifiedSchema = Option(schema)
-    this
+    schema(StructType.fromDDL(schemaString))
   }
 
   /**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala
index d82fa9e..06c7579 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala
@@ -64,8 +64,10 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo
    * @since 2.0.0
    */
   def schema(schema: StructType): DataStreamReader = {
-    val replaced = CharVarcharUtils.failIfHasCharVarchar(schema).asInstanceOf[StructType]
-    this.userSpecifiedSchema = Option(replaced)
+    if (schema != null) {
+      val replaced = CharVarcharUtils.failIfHasCharVarchar(schema).asInstanceOf[StructType]
+      this.userSpecifiedSchema = Option(replaced)
+    }
     this
   }
 
@@ -77,10 +79,7 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo
    * @since 2.3.0
    */
   def schema(schemaString: String): DataStreamReader = {
-    val rawSchema = StructType.fromDDL(schemaString)
-    val schema = CharVarcharUtils.failIfHasCharVarchar(rawSchema).asInstanceOf[StructType]
-    this.userSpecifiedSchema = Option(schema)
-    this
+    schema(StructType.fromDDL(schemaString))
   }
 
   /**


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org