You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by do...@apache.org on 2018/12/08 19:18:21 UTC
spark git commit: [SPARK-26021][SQL][FOLLOWUP] only deal with NaN and
-0.0 in UnsafeWriter
Repository: spark
Updated Branches:
refs/heads/master 678e1aca6 -> bdf32847b
[SPARK-26021][SQL][FOLLOWUP] only deal with NaN and -0.0 in UnsafeWriter
## What changes were proposed in this pull request?
A followup of https://github.com/apache/spark/pull/23043
There are 4 places where we need to deal with NaN and -0.0:
1. Comparison expressions. `-0.0` and `0.0` should be treated as the same value. Different NaNs should be treated as the same value.
2. Join keys. `-0.0` and `0.0` should be treated as the same value. Different NaNs should be treated as the same value.
3. Grouping keys. `-0.0` and `0.0` should be assigned to the same group. Different NaNs should be assigned to the same group.
4. Window partition keys. `-0.0` and `0.0` should be treated as the same value. Different NaNs should be treated as the same value.
Case 1 is fine. Our comparison already handles NaN and -0.0, and for struct/array/map we recursively compare the fields/elements.
Cases 2, 3 and 4 are problematic, as they compare `UnsafeRow` binaries directly. Different NaNs have different binary representations, and the same is true for -0.0 and 0.0.
To fix it, a simple solution is to normalize float/double values when building unsafe data (`UnsafeRow`, `UnsafeArrayData`, `UnsafeMapData`). Then we don't need to worry about it anymore.
Following this direction, this PR moves the handling of NaN and -0.0 from `Platform` to `UnsafeWriter`, so that places like `UnsafeRow.setFloat` no longer normalize values, which reduces the performance overhead. It's also easier to add comments explaining why we do it in `UnsafeWriter`.
## How was this patch tested?
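Existing tests, plus new tests in `UnsafeRowWriterSuite`, `DataFrameJoinSuite` and `DataFrameWindowFunctionsSuite`.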
Closes #23239 from cloud-fan/minor.
Authored-by: Wenchen Fan <we...@databricks.com>
Signed-off-by: Dongjoon Hyun <do...@apache.org>
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/bdf32847
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/bdf32847
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/bdf32847
Branch: refs/heads/master
Commit: bdf32847b1ffcb3aa4d0bef058f86e65656e99fb
Parents: 678e1ac
Author: Wenchen Fan <we...@databricks.com>
Authored: Sat Dec 8 11:18:09 2018 -0800
Committer: Dongjoon Hyun <do...@apache.org>
Committed: Sat Dec 8 11:18:09 2018 -0800
----------------------------------------------------------------------
.../java/org/apache/spark/unsafe/Platform.java | 10 ------
.../apache/spark/unsafe/PlatformUtilSuite.java | 18 ----------
.../expressions/codegen/UnsafeWriter.java | 35 ++++++++++++++++++++
.../codegen/UnsafeRowWriterSuite.scala | 20 +++++++++++
.../apache/spark/sql/DataFrameJoinSuite.scala | 12 +++++++
.../sql/DataFrameWindowFunctionsSuite.scala | 14 ++++++++
6 files changed, 81 insertions(+), 28 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/bdf32847/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java
----------------------------------------------------------------------
diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java
index 4563efc..076b693 100644
--- a/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java
+++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java
@@ -174,11 +174,6 @@ public final class Platform {
}
public static void putFloat(Object object, long offset, float value) {
- if (Float.isNaN(value)) {
- value = Float.NaN;
- } else if (value == -0.0f) {
- value = 0.0f;
- }
_UNSAFE.putFloat(object, offset, value);
}
@@ -187,11 +182,6 @@ public final class Platform {
}
public static void putDouble(Object object, long offset, double value) {
- if (Double.isNaN(value)) {
- value = Double.NaN;
- } else if (value == -0.0d) {
- value = 0.0d;
- }
_UNSAFE.putDouble(object, offset, value);
}
http://git-wip-us.apache.org/repos/asf/spark/blob/bdf32847/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java
----------------------------------------------------------------------
diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java
index 2474081..3ad9ac7 100644
--- a/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java
+++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java
@@ -157,22 +157,4 @@ public class PlatformUtilSuite {
Assert.assertEquals(onheap4.size(), 1024 * 1024 + 7);
Assert.assertEquals(obj3, onheap4.getBaseObject());
}
-
- @Test
- // SPARK-26021
- public void writeMinusZeroIsReplacedWithZero() {
- byte[] doubleBytes = new byte[Double.BYTES];
- byte[] floatBytes = new byte[Float.BYTES];
- Platform.putDouble(doubleBytes, Platform.BYTE_ARRAY_OFFSET, -0.0d);
- Platform.putFloat(floatBytes, Platform.BYTE_ARRAY_OFFSET, -0.0f);
-
- byte[] doubleBytes2 = new byte[Double.BYTES];
- byte[] floatBytes2 = new byte[Float.BYTES];
- Platform.putDouble(doubleBytes, Platform.BYTE_ARRAY_OFFSET, 0.0d);
- Platform.putFloat(floatBytes, Platform.BYTE_ARRAY_OFFSET, 0.0f);
-
- // Make sure the bytes we write from 0.0 and -0.0 are same.
- Assert.assertArrayEquals(doubleBytes, doubleBytes2);
- Assert.assertArrayEquals(floatBytes, floatBytes2);
- }
}
http://git-wip-us.apache.org/repos/asf/spark/blob/bdf32847/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java
index 95263a0..7553ab8 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java
@@ -198,11 +198,46 @@ public abstract class UnsafeWriter {
Platform.putLong(getBuffer(), offset, value);
}
+ // We need to take care of NaN and -0.0 in several places:
+ // 1. When comparing values, different NaNs should be treated as same, `-0.0` and `0.0`
+ // should be treated as same.
+ // 2. In GROUP BY, different NaNs should belong to the same group, -0.0 and 0.0 should belong
+ // to the same group.
+ // 3. As join keys, different NaNs should be treated as same, `-0.0` and `0.0` should be
+ // treated as same.
+ // 4. As window partition keys, different NaNs should be treated as same, `-0.0` and `0.0`
+ // should be treated as same.
+ //
+ // Case 1 is fine, as we handle NaN and -0.0 well during comparison. For complex types, we
+ // recursively compare the fields/elements, so it's also fine.
+ //
+ // Case 2, 3 and 4 are problematic, as they compare `UnsafeRow` binary directly, and different
+ // NaNs have different binary representation, and the same thing happens for -0.0 and 0.0.
+ //
+ // Here we normalize NaN and -0.0, so that `UnsafeProjection` will normalize them when writing
+ // float/double columns and nested fields to `UnsafeRow`.
+ //
+ // Note that, we must do this for all the `UnsafeProjection`s, not only the ones that extract
+ // join/grouping/window partition keys. `UnsafeProjection` copies unsafe data directly for complex
+ // types, so nested float/double may not be normalized. We need to make sure that all the unsafe
+ // data (`UnsafeRow`, `UnsafeArrayData`, `UnsafeMapData`) will have float/double normalized during
+ // creation.
protected final void writeFloat(long offset, float value) {
+ if (Float.isNaN(value)) {
+ value = Float.NaN;
+ } else if (value == -0.0f) {
+ value = 0.0f;
+ }
Platform.putFloat(getBuffer(), offset, value);
}
+ // See comments for `writeFloat`.
protected final void writeDouble(long offset, double value) {
+ if (Double.isNaN(value)) {
+ value = Double.NaN;
+ } else if (value == -0.0d) {
+ value = 0.0d;
+ }
Platform.putDouble(getBuffer(), offset, value);
}
}
http://git-wip-us.apache.org/repos/asf/spark/blob/bdf32847/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriterSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriterSuite.scala
index fb651b7..22e1fa6 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriterSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriterSuite.scala
@@ -50,4 +50,24 @@ class UnsafeRowWriterSuite extends SparkFunSuite {
assert(res1 == res2)
}
+ test("SPARK-26021: normalize float/double NaN and -0.0") {
+ val unsafeRowWriter1 = new UnsafeRowWriter(4)
+ unsafeRowWriter1.resetRowWriter()
+ unsafeRowWriter1.write(0, Float.NaN)
+ unsafeRowWriter1.write(1, Double.NaN)
+ unsafeRowWriter1.write(2, 0.0f)
+ unsafeRowWriter1.write(3, 0.0)
+ val res1 = unsafeRowWriter1.getRow
+
+ val unsafeRowWriter2 = new UnsafeRowWriter(4)
+ unsafeRowWriter2.resetRowWriter()
+ unsafeRowWriter2.write(0, 0.0f/0.0f)
+ unsafeRowWriter2.write(1, 0.0/0.0)
+ unsafeRowWriter2.write(2, -0.0f)
+ unsafeRowWriter2.write(3, -0.0)
+ val res2 = unsafeRowWriter2.getRow
+
+ // The two rows should be equal
+ assert(res1 == res2)
+ }
}
http://git-wip-us.apache.org/repos/asf/spark/blob/bdf32847/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
index e6b30f9..c9f41ab 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
@@ -295,4 +295,16 @@ class DataFrameJoinSuite extends QueryTest with SharedSQLContext {
df.join(df, df("id") <=> df("id")).queryExecution.optimizedPlan
}
}
+
+ test("NaN and -0.0 in join keys") {
+ val df1 = Seq(Float.NaN -> Double.NaN, 0.0f -> 0.0, -0.0f -> -0.0).toDF("f", "d")
+ val df2 = Seq(Float.NaN -> Double.NaN, 0.0f -> 0.0, -0.0f -> -0.0).toDF("f", "d")
+ val joined = df1.join(df2, Seq("f", "d"))
+ checkAnswer(joined, Seq(
+ Row(Float.NaN, Double.NaN),
+ Row(0.0f, 0.0),
+ Row(0.0f, 0.0),
+ Row(0.0f, 0.0),
+ Row(0.0f, 0.0)))
+ }
}
http://git-wip-us.apache.org/repos/asf/spark/blob/bdf32847/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala
index 78277d7..9a5d5a9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala
@@ -681,4 +681,18 @@ class DataFrameWindowFunctionsSuite extends QueryTest with SharedSQLContext {
Row("S2", "P2", 300, 300, 500)))
}
+
+ test("NaN and -0.0 in window partition keys") {
+ val df = Seq(
+ (Float.NaN, Double.NaN, 1),
+ (0.0f/0.0f, 0.0/0.0, 1),
+ (0.0f, 0.0, 1),
+ (-0.0f, -0.0, 1)).toDF("f", "d", "i")
+ val result = df.select($"f", count("i").over(Window.partitionBy("f", "d")))
+ checkAnswer(result, Seq(
+ Row(Float.NaN, 2),
+ Row(Float.NaN, 2),
+ Row(0.0f, 2),
+ Row(0.0f, 2)))
+ }
}
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org