You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by gu...@apache.org on 2020/07/02 05:00:49 UTC

[spark] branch branch-3.0 updated: [SPARK-32136][SQL] NormalizeFloatingNumbers should work on null struct

This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch branch-3.0
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.0 by this push:
     new 7f11c8f  [SPARK-32136][SQL] NormalizeFloatingNumbers should work on null struct
7f11c8f is described below

commit 7f11c8f05478391534f871f7c70f13391b5c69ba
Author: Liang-Chi Hsieh <vi...@gmail.com>
AuthorDate: Thu Jul 2 13:56:43 2020 +0900

    [SPARK-32136][SQL] NormalizeFloatingNumbers should work on null struct
    
    ### What changes were proposed in this pull request?
    
    This patch fixes an incorrect groupBy result when the grouping key is a null struct value.
    
    ### Why are the changes needed?
    
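    `NormalizeFloatingNumbers` reconstructs a struct when the input expression is of `StructType`. If the input struct is null, it reconstructs a non-null struct whose fields are all null, instead of returning null. As a result, a null struct key and a struct key whose fields are all null are incorrectly merged into the same group.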
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes. It fixes incorrect groupBy results when grouping by a struct column that contains null structs.
    
    ### How was this patch tested?
    
    Added a unit test to `DataFrameAggregateSuite`.
    
    Closes #28962 from viirya/SPARK-32136.
    
    Authored-by: Liang-Chi Hsieh <vi...@gmail.com>
    Signed-off-by: HyukjinKwon <gu...@apache.org>
    (cherry picked from commit 3f7780d30d712e6d3894bacb5e80113c7a4bcc09)
    Signed-off-by: HyukjinKwon <gu...@apache.org>
---
 .../sql/catalyst/optimizer/NormalizeFloatingNumbers.scala    |  5 +++--
 .../scala/org/apache/spark/sql/DataFrameAggregateSuite.scala | 12 ++++++++++++
 2 files changed, 15 insertions(+), 2 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala
index 4373820..8d5dbc7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql.catalyst.optimizer
 
-import org.apache.spark.sql.catalyst.expressions.{Alias, And, ArrayTransform, CreateArray, CreateMap, CreateNamedStruct, CreateStruct, EqualTo, ExpectsInputTypes, Expression, GetStructField, KnownFloatingPointNormalized, LambdaFunction, NamedLambdaVariable, UnaryExpression}
+import org.apache.spark.sql.catalyst.expressions.{Alias, And, ArrayTransform, CreateArray, CreateMap, CreateNamedStruct, CreateStruct, EqualTo, ExpectsInputTypes, Expression, GetStructField, If, IsNull, KnownFloatingPointNormalized, LambdaFunction, Literal, NamedLambdaVariable, UnaryExpression}
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
 import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
 import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Subquery, Window}
@@ -123,7 +123,8 @@ object NormalizeFloatingNumbers extends Rule[LogicalPlan] {
       val fields = expr.dataType.asInstanceOf[StructType].fields.indices.map { i =>
         normalize(GetStructField(expr, i))
       }
-      CreateStruct(fields)
+      val struct = CreateStruct(fields)
+      KnownFloatingPointNormalized(If(IsNull(expr), Literal(null, struct.dataType), struct))
 
     case _ if expr.dataType.isInstanceOf[ArrayType] =>
       val ArrayType(et, containsNull) = expr.dataType
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index f7438f3..09f30bb 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -1028,4 +1028,16 @@ class DataFrameAggregateSuite extends QueryTest
       checkAnswer(df, Row("abellina", 2) :: Row("mithunr", 1) :: Nil)
     }
   }
+
+  test("SPARK-32136: NormalizeFloatingNumbers should work on null struct") {
+    val df = Seq(
+      A(None),
+      A(Some(B(None))),
+      A(Some(B(Some(1.0))))).toDF
+    val groupBy = df.groupBy("b").agg(count("*"))
+    checkAnswer(groupBy, Row(null, 1) :: Row(Row(null), 1) :: Row(Row(1.0), 1) :: Nil)
+  }
 }
+
+case class B(c: Option[Double])
+case class A(b: Option[B])


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org