Posted to commits@spark.apache.org by vi...@apache.org on 2020/09/18 05:14:45 UTC
[spark] branch branch-3.0 updated: [SPARK-32906][SQL] Struct field names should not change after normalizing floats
This is an automated email from the ASF dual-hosted git repository.
viirya pushed a commit to branch branch-3.0
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.0 by this push:
new 2d55de5 [SPARK-32906][SQL] Struct field names should not change after normalizing floats
2d55de5 is described below
commit 2d55de5db06ff522a8e7fded3760c3f65c8bbf93
Author: Takeshi Yamamuro <ya...@apache.org>
AuthorDate: Thu Sep 17 22:07:47 2020 -0700
[SPARK-32906][SQL] Struct field names should not change after normalizing floats
### What changes were proposed in this pull request?
This PR intends to fix a minor bug when normalizing floats for struct types:
```
scala> import org.apache.spark.sql.execution.aggregate.HashAggregateExec
scala> val df = Seq(Tuple1(Tuple1(-0.0d)), Tuple1(Tuple1(0.0d))).toDF("k")
scala> val agg = df.distinct()
scala> agg.explain()
== Physical Plan ==
*(2) HashAggregate(keys=[k#40], functions=[])
+- Exchange hashpartitioning(k#40, 200), true, [id=#62]
   +- *(1) HashAggregate(keys=[knownfloatingpointnormalized(if (isnull(k#40)) null else named_struct(col1, knownfloatingpointnormalized(normalizenanandzero(k#40._1)))) AS k#40], functions=[])
      +- *(1) LocalTableScan [k#40]
scala> val aggOutput = agg.queryExecution.sparkPlan.collect { case a: HashAggregateExec => a.output.head }
scala> aggOutput.foreach { attr => println(attr.prettyJson) }
### Final Aggregate ###
[ {
  "class" : "org.apache.spark.sql.catalyst.expressions.AttributeReference",
  "num-children" : 0,
  "name" : "k",
  "dataType" : {
    "type" : "struct",
    "fields" : [ {
      "name" : "_1",
                ^^^
      "type" : "double",
      "nullable" : false,
      "metadata" : { }
    } ]
  },
  "nullable" : true,
  "metadata" : { },
  "exprId" : {
    "product-class" : "org.apache.spark.sql.catalyst.expressions.ExprId",
    "id" : 40,
    "jvmId" : "a824e83f-933e-4b85-a1ff-577b5a0e2366"
  },
  "qualifier" : [ ]
} ]
### Partial Aggregate ###
[ {
  "class" : "org.apache.spark.sql.catalyst.expressions.AttributeReference",
  "num-children" : 0,
  "name" : "k",
  "dataType" : {
    "type" : "struct",
    "fields" : [ {
      "name" : "col1",
                ^^^^
      "type" : "double",
      "nullable" : true,
      "metadata" : { }
    } ]
  },
  "nullable" : true,
  "metadata" : { },
  "exprId" : {
    "product-class" : "org.apache.spark.sql.catalyst.expressions.ExprId",
    "id" : 40,
    "jvmId" : "a824e83f-933e-4b85-a1ff-577b5a0e2366"
  },
  "qualifier" : [ ]
} ]
```
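The root cause is visible in the plans above: when the rule rebuilds a struct, `CreateStruct` falls back to positional default names (`col1`, `col2`, ...) because the normalized children are no longer named expressions, so the partial aggregate's output schema diverges from the final aggregate's. The fix switches to `CreateNamedStruct` and threads the original field names through. A minimal sketch of the difference, assuming a Spark 3.0 shell with Catalyst on the classpath (the plain literals here merely stand in for the normalized struct fields):
```
import org.apache.spark.sql.catalyst.expressions.{CreateNamedStruct, CreateStruct, Literal}

// CreateStruct assigns positional default names ("col1", "col2", ...) to
// children that are not named expressions, so the original name is lost:
val unnamed = CreateStruct(Seq(Literal(1.0d)))
assert(unnamed.dataType.simpleString == "struct<col1:double>")

// CreateNamedStruct takes interleaved (name, value) pairs, so the original
// field name ("_1" in the repro above) can be passed through explicitly:
val named = CreateNamedStruct(Seq(Literal("_1"), Literal(1.0d)))
assert(named.dataType.simpleString == "struct<_1:double>")
```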
### Why are the changes needed?
Bug fix: as the plans above show, the struct field name in the partial aggregate output (`col1`) did not match the one in the final aggregate output (`_1`).
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Added a test in `DataFrameAggregateSuite`.
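(For local verification, a command along these lines should run just the new test, though the exact invocation is an assumption rather than part of this commit: `build/sbt "sql/testOnly *DataFrameAggregateSuite -- -z SPARK-32906"`.)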
Closes #29780 from maropu/FixBugInNormalizedFloatingNumbers.
Authored-by: Takeshi Yamamuro <ya...@apache.org>
Signed-off-by: Liang-Chi Hsieh <vi...@gmail.com>
(cherry picked from commit b49aaa33e13814a448be51a7e65a29cb515b8248)
Signed-off-by: Liang-Chi Hsieh <vi...@gmail.com>
---
.../spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala | 6 +++---
.../test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala | 8 ++++++++
2 files changed, 11 insertions(+), 3 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala
index 8d5dbc7..f0cf671 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala
@@ -120,10 +120,10 @@ object NormalizeFloatingNumbers extends Rule[LogicalPlan] {
       KnownFloatingPointNormalized(NormalizeNaNAndZero(expr))
 
     case _ if expr.dataType.isInstanceOf[StructType] =>
-      val fields = expr.dataType.asInstanceOf[StructType].fields.indices.map { i =>
-        normalize(GetStructField(expr, i))
+      val fields = expr.dataType.asInstanceOf[StructType].fieldNames.zipWithIndex.map {
+        case (name, i) => Seq(Literal(name), normalize(GetStructField(expr, i)))
       }
-      val struct = CreateStruct(fields)
+      val struct = CreateNamedStruct(fields.flatten.toSeq)
       KnownFloatingPointNormalized(If(IsNull(expr), Literal(null, struct.dataType), struct))
 
     case _ if expr.dataType.isInstanceOf[ArrayType] =>
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index 54327b3..2cb7790 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -1044,6 +1044,14 @@ class DataFrameAggregateSuite extends QueryTest
     checkAnswer(sql(queryTemplate("FIRST")), Row(1))
     checkAnswer(sql(queryTemplate("LAST")), Row(3))
   }
+
+  test("SPARK-32906: struct field names should not change after normalizing floats") {
+    val df = Seq(Tuple1(Tuple2(-0.0d, Double.NaN)), Tuple1(Tuple2(0.0d, Double.NaN))).toDF("k")
+    val aggs = df.distinct().queryExecution.sparkPlan.collect { case a: HashAggregateExec => a }
+    assert(aggs.length == 2)
+    assert(aggs.head.output.map(_.dataType.simpleString).head ===
+      aggs.last.output.map(_.dataType.simpleString).head)
+  }
 }
 
 case class B(c: Option[Double])