You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by ge...@apache.org on 2022/03/02 02:56:44 UTC

[spark] branch branch-3.2 updated: [SPARK-38363][SQL] Avoid runtime error in Dataset.summary()/Dataset.describe() when ANSI mode is on

This is an automated email from the ASF dual-hosted git repository.

gengliang pushed a commit to branch branch-3.2
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.2 by this push:
     new 78c7c14  [SPARK-38363][SQL] Avoid runtime error in Dataset.summary()/Dataset.describe() when ANSI mode is on
78c7c14 is described below

commit 78c7c14c7f8e4e5c343d68dffc41801430a45699
Author: Gengliang Wang <ge...@apache.org>
AuthorDate: Wed Mar 2 10:54:03 2022 +0800

    [SPARK-38363][SQL] Avoid runtime error in Dataset.summary()/Dataset.describe() when ANSI mode is on
    
    ### What changes were proposed in this pull request?
    
    When executing `df.summary()` or `df.describe()`, Spark SQL casts String columns to Double when computing the
    percentile/mean/stddev statistics.
    ```
    scala> val person2: DataFrame = Seq(
         |     ("Bob", 16, 176),
         |     ("Alice", 32, 164),
         |     ("David", 60, 192),
         |     ("Amy", 24, 180)).toDF("name", "age", "height")
    
    scala> person2.summary().show()
    +-------+-----+------------------+------------------+
    |summary| name|               age|            height|
    +-------+-----+------------------+------------------+
    |  count|    4|                 4|                 4|
    |   mean| null|              33.0|             178.0|
    | stddev| null|19.148542155126762|11.547005383792515|
    |    min|Alice|                16|               164|
    |    25%| null|                16|               164|
    |    50%| null|                24|               176|
    |    75%| null|                32|               180|
    |    max|David|                60|               192|
    +-------+-----+------------------+------------------+
    ```
    
    This can cause runtime errors with ANSI mode on.
    ```
    org.apache.spark.SparkNumberFormatException: invalid input syntax for type numeric: Bob
    ```
    This PR fixes the issue by using `TryCast` for String columns, so that values which cannot be parsed as numbers become null instead of raising an error.
    
    ### Why are the changes needed?
    
    This improves adoption of ANSI mode. Since both APIs are meant for getting a quick summary of a DataFrame, I suggest using `TryCast` for the problematic stats so that both APIs still work under ANSI mode.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    UT
    
    Closes #35699 from gengliangwang/fixSummary.
    
    Authored-by: Gengliang Wang <ge...@apache.org>
    Signed-off-by: Gengliang Wang <ge...@apache.org>
    (cherry picked from commit 80f25ad24a871f0ddef939f6a3e2f01370f1fa6f)
    Signed-off-by: Gengliang Wang <ge...@apache.org>
---
 .../spark/sql/execution/stat/StatFunctions.scala   | 15 +++-
 .../org/apache/spark/sql/DataFrameSuite.scala      | 86 ++++++++++++----------
 2 files changed, 58 insertions(+), 43 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
index 5dc0ff0..9155c1c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
@@ -22,7 +22,7 @@ import java.util.Locale
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.{Column, DataFrame, Dataset, Row}
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Cast, Expression, GenericInternalRow, GetArrayItem, Literal}
+import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Cast, Expression, GenericInternalRow, GetArrayItem, Literal, TryCast}
 import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
 import org.apache.spark.sql.catalyst.util.{GenericArrayData, QuantileSummaries}
@@ -246,6 +246,11 @@ object StatFunctions extends Logging {
     }
     require(percentiles.forall(p => p >= 0 && p <= 1), "Percentiles must be in the range [0, 1]")
 
+    def castAsDoubleIfNecessary(e: Expression): Expression = if (e.dataType == StringType) {
+      TryCast(e, DoubleType)
+    } else {
+      e
+    }
     var percentileIndex = 0
     val statisticFns = selectedStatistics.map { stats =>
       if (stats.endsWith("%")) {
@@ -253,7 +258,7 @@ object StatFunctions extends Logging {
         percentileIndex += 1
         (child: Expression) =>
           GetArrayItem(
-            new ApproximatePercentile(child,
+            new ApproximatePercentile(castAsDoubleIfNecessary(child),
               Literal(new GenericArrayData(percentiles), ArrayType(DoubleType, false)))
               .toAggregateExpression(),
             Literal(index))
@@ -264,8 +269,10 @@ object StatFunctions extends Logging {
             Count(child).toAggregateExpression(isDistinct = true)
           case "approx_count_distinct" => (child: Expression) =>
             HyperLogLogPlusPlus(child).toAggregateExpression()
-          case "mean" => (child: Expression) => Average(child).toAggregateExpression()
-          case "stddev" => (child: Expression) => StddevSamp(child).toAggregateExpression()
+          case "mean" => (child: Expression) =>
+            Average(castAsDoubleIfNecessary(child)).toAggregateExpression()
+          case "stddev" => (child: Expression) =>
+            StddevSamp(castAsDoubleIfNecessary(child)).toAggregateExpression()
           case "min" => (child: Expression) => Min(child).toAggregateExpression()
           case "max" => (child: Expression) => Max(child).toAggregateExpression()
           case _ => throw QueryExecutionErrors.statisticNotRecognizedError(stats)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 70ec052..e427c43 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -866,29 +866,33 @@ class DataFrameSuite extends QueryTest
 
     def getSchemaAsSeq(df: DataFrame): Seq[String] = df.schema.map(_.name)
 
-    val describeAllCols = person2.describe()
-    assert(getSchemaAsSeq(describeAllCols) === Seq("summary", "name", "age", "height"))
-    checkAnswer(describeAllCols, describeResult)
-    // All aggregate value should have been cast to string
-    describeAllCols.collect().foreach { row =>
-      row.toSeq.foreach { value =>
-        if (value != null) {
-          assert(value.isInstanceOf[String], "expected string but found " + value.getClass)
+    Seq("true", "false").foreach { ansiEnabled =>
+      withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled) {
+        val describeAllCols = person2.describe()
+        assert(getSchemaAsSeq(describeAllCols) === Seq("summary", "name", "age", "height"))
+        checkAnswer(describeAllCols, describeResult)
+        // All aggregate value should have been cast to string
+        describeAllCols.collect().foreach { row =>
+          row.toSeq.foreach { value =>
+            if (value != null) {
+              assert(value.isInstanceOf[String], "expected string but found " + value.getClass)
+            }
+          }
         }
-      }
-    }
 
-    val describeOneCol = person2.describe("age")
-    assert(getSchemaAsSeq(describeOneCol) === Seq("summary", "age"))
-    checkAnswer(describeOneCol, describeResult.map { case Row(s, _, d, _) => Row(s, d)} )
+        val describeOneCol = person2.describe("age")
+        assert(getSchemaAsSeq(describeOneCol) === Seq("summary", "age"))
+        checkAnswer(describeOneCol, describeResult.map { case Row(s, _, d, _) => Row(s, d) })
 
-    val describeNoCol = person2.select().describe()
-    assert(getSchemaAsSeq(describeNoCol) === Seq("summary"))
-    checkAnswer(describeNoCol, describeResult.map { case Row(s, _, _, _) => Row(s)} )
+        val describeNoCol = person2.select().describe()
+        assert(getSchemaAsSeq(describeNoCol) === Seq("summary"))
+        checkAnswer(describeNoCol, describeResult.map { case Row(s, _, _, _) => Row(s) })
 
-    val emptyDescription = person2.limit(0).describe()
-    assert(getSchemaAsSeq(emptyDescription) === Seq("summary", "name", "age", "height"))
-    checkAnswer(emptyDescription, emptyDescribeResult)
+        val emptyDescription = person2.limit(0).describe()
+        assert(getSchemaAsSeq(emptyDescription) === Seq("summary", "name", "age", "height"))
+        checkAnswer(emptyDescription, emptyDescribeResult)
+      }
+    }
   }
 
   test("summary") {
@@ -914,30 +918,34 @@ class DataFrameSuite extends QueryTest
 
     def getSchemaAsSeq(df: DataFrame): Seq[String] = df.schema.map(_.name)
 
-    val summaryAllCols = person2.summary()
-
-    assert(getSchemaAsSeq(summaryAllCols) === Seq("summary", "name", "age", "height"))
-    checkAnswer(summaryAllCols, summaryResult)
-    // All aggregate value should have been cast to string
-    summaryAllCols.collect().foreach { row =>
-      row.toSeq.foreach { value =>
-        if (value != null) {
-          assert(value.isInstanceOf[String], "expected string but found " + value.getClass)
+    Seq("true", "false").foreach { ansiEnabled =>
+      withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled) {
+        val summaryAllCols = person2.summary()
+
+        assert(getSchemaAsSeq(summaryAllCols) === Seq("summary", "name", "age", "height"))
+        checkAnswer(summaryAllCols, summaryResult)
+        // All aggregate value should have been cast to string
+        summaryAllCols.collect().foreach { row =>
+          row.toSeq.foreach { value =>
+            if (value != null) {
+              assert(value.isInstanceOf[String], "expected string but found " + value.getClass)
+            }
+          }
         }
-      }
-    }
 
-    val summaryOneCol = person2.select("age").summary()
-    assert(getSchemaAsSeq(summaryOneCol) === Seq("summary", "age"))
-    checkAnswer(summaryOneCol, summaryResult.map { case Row(s, _, d, _) => Row(s, d)} )
+        val summaryOneCol = person2.select("age").summary()
+        assert(getSchemaAsSeq(summaryOneCol) === Seq("summary", "age"))
+        checkAnswer(summaryOneCol, summaryResult.map { case Row(s, _, d, _) => Row(s, d) })
 
-    val summaryNoCol = person2.select().summary()
-    assert(getSchemaAsSeq(summaryNoCol) === Seq("summary"))
-    checkAnswer(summaryNoCol, summaryResult.map { case Row(s, _, _, _) => Row(s)} )
+        val summaryNoCol = person2.select().summary()
+        assert(getSchemaAsSeq(summaryNoCol) === Seq("summary"))
+        checkAnswer(summaryNoCol, summaryResult.map { case Row(s, _, _, _) => Row(s) })
 
-    val emptyDescription = person2.limit(0).summary()
-    assert(getSchemaAsSeq(emptyDescription) === Seq("summary", "name", "age", "height"))
-    checkAnswer(emptyDescription, emptySummaryResult)
+        val emptyDescription = person2.limit(0).summary()
+        assert(getSchemaAsSeq(emptyDescription) === Seq("summary", "name", "age", "height"))
+        checkAnswer(emptyDescription, emptySummaryResult)
+      }
+    }
   }
 
   test("SPARK-34165: Add count_distinct to summary") {

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org