Posted to commits@spark.apache.org by do...@apache.org on 2018/12/09 18:50:48 UTC

spark git commit: [SPARK-26021][2.4][SQL][FOLLOWUP] only deal with NaN and -0.0 in UnsafeWriter

Repository: spark
Updated Branches:
  refs/heads/branch-2.4 a073b1c69 -> 33460c58a


[SPARK-26021][2.4][SQL][FOLLOWUP] only deal with NaN and -0.0 in UnsafeWriter

backport https://github.com/apache/spark/pull/23239 to 2.4

---------

## What changes were proposed in this pull request?

A followup of https://github.com/apache/spark/pull/23043

There are 4 places where we need to deal with NaN and -0.0:
1. Comparison expressions: `-0.0` and `0.0` should be treated as the same value, and different NaNs should be treated as the same value.
2. Join keys: `-0.0` and `0.0` should be treated as the same value, and different NaNs should be treated as the same value.
3. Grouping keys: `-0.0` and `0.0` should be assigned to the same group, and different NaNs should be assigned to the same group (a sketch of this symptom follows the list).
4. Window partition keys: `-0.0` and `0.0` should be treated as the same value, and different NaNs should be treated as the same value.
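
As a concrete illustration of case 3 (a hypothetical sketch, not part of the patch; it assumes a running `SparkSession` named `spark`): grouping compares the raw `UnsafeRow` bytes, so before this fix `0.0` and `-0.0` could land in separate groups:

```scala
import spark.implicits._

// 0.0 and -0.0 compare as equal, but have different IEEE 754 bit patterns.
val df = Seq(0.0, -0.0).toDF("d")

// With the buggy binary comparison this could return two groups of count 1;
// with normalized unsafe data it returns a single group of count 2.
df.groupBy("d").count().show()
```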

Case 1 is OK: our comparison already handles NaN and -0.0, and for struct/array/map we recursively compare the fields/elements.

Cases 2, 3 and 4 are problematic, as they compare `UnsafeRow` binaries directly: different NaNs can have different binary representations, and the same is true for -0.0 and 0.0.
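
To see why the binary comparison breaks (a standalone sketch, not taken from the patch), the bit patterns can be inspected with the standard JDK helpers:

```scala
// -0.0 and 0.0 compare as equal yet differ in their IEEE 754 bit patterns.
val posZero = java.lang.Double.doubleToRawLongBits(0.0)   // 0x0000000000000000L
val negZero = java.lang.Double.doubleToRawLongBits(-0.0)  // 0x8000000000000000L
assert(0.0 == -0.0 && posZero != negZero)

// IEEE 754 permits many NaN bit patterns. Setting a payload bit on the
// canonical NaN yields another NaN with different raw bits (on typical
// IEEE 754 hardware the bit pattern survives the round trip).
val canonicalNaN = java.lang.Double.doubleToRawLongBits(Double.NaN)
val anotherNaN = java.lang.Double.longBitsToDouble(canonicalNaN | 1L)
assert(anotherNaN.isNaN)
assert(java.lang.Double.doubleToRawLongBits(anotherNaN) != canonicalNaN)
```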

To fix it, a simple solution is to normalize float/double values when building unsafe data (`UnsafeRow`, `UnsafeArrayData`, `UnsafeMapData`). Then we don't need to worry about it anywhere else.

Following this direction, this PR moves the handling of NaN and -0.0 from `Platform` to `UnsafeWriter`, so that places like `UnsafeRow.setFloat` no longer pay for it, which reduces the performance overhead. It also makes it easier to add comments in `UnsafeWriter` explaining why we do it. A sketch of the normalization follows.
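
The normalization the patch adds to `UnsafeWriter.writeFloat`/`writeDouble` (shown in Java in the diff below) boils down to the following; here is the same logic as a standalone Scala sketch with hypothetical helper names:

```scala
// Collapse every NaN to the canonical NaN and -0.0 to 0.0, so that values
// that compare as equal also share a single binary representation.
def normalizeFloat(value: Float): Float = {
  if (value.isNaN) Float.NaN
  else if (value == -0.0f) 0.0f  // 0.0f == -0.0f, so this branch also
  else value                     // (harmlessly) rewrites 0.0f to itself
}

def normalizeDouble(value: Double): Double = {
  if (value.isNaN) Double.NaN
  else if (value == -0.0d) 0.0d
  else value
}
```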

## How was this patch tested?

Existing tests.

Closes #23265 from cloud-fan/minor.

Authored-by: Wenchen Fan <we...@databricks.com>
Signed-off-by: Dongjoon Hyun <do...@apache.org>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/33460c58
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/33460c58
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/33460c58

Branch: refs/heads/branch-2.4
Commit: 33460c58a9274e22bd662858c71292275ae4aa24
Parents: a073b1c
Author: Wenchen Fan <we...@databricks.com>
Authored: Sun Dec 9 10:50:41 2018 -0800
Committer: Dongjoon Hyun <do...@apache.org>
Committed: Sun Dec 9 10:50:41 2018 -0800

----------------------------------------------------------------------
 .../java/org/apache/spark/unsafe/Platform.java  | 10 ------
 .../apache/spark/unsafe/PlatformUtilSuite.java  | 14 --------
 .../expressions/codegen/UnsafeWriter.java       | 35 ++++++++++++++++++++
 .../codegen/UnsafeRowWriterSuite.scala          | 20 +++++++++++
 .../apache/spark/sql/DataFrameJoinSuite.scala   | 12 +++++++
 .../sql/DataFrameWindowFunctionsSuite.scala     | 14 ++++++++
 6 files changed, 81 insertions(+), 24 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/33460c58/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java
----------------------------------------------------------------------
diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java
index bc94f21..aca6fca 100644
--- a/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java
+++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java
@@ -120,11 +120,6 @@ public final class Platform {
   }
 
   public static void putFloat(Object object, long offset, float value) {
-    if (Float.isNaN(value)) {
-      value = Float.NaN;
-    } else if (value == -0.0f) {
-      value = 0.0f;
-    }
     _UNSAFE.putFloat(object, offset, value);
   }
 
@@ -133,11 +128,6 @@ public final class Platform {
   }
 
   public static void putDouble(Object object, long offset, double value) {
-    if (Double.isNaN(value)) {
-      value = Double.NaN;
-    } else if (value == -0.0d) {
-      value = 0.0d;
-    }
     _UNSAFE.putDouble(object, offset, value);
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/33460c58/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java
----------------------------------------------------------------------
diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java
index ab34324..3ad9ac7 100644
--- a/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java
+++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java
@@ -157,18 +157,4 @@ public class PlatformUtilSuite {
     Assert.assertEquals(onheap4.size(), 1024 * 1024 + 7);
     Assert.assertEquals(obj3, onheap4.getBaseObject());
   }
-
-  @Test
-  // SPARK-26021
-  public void writeMinusZeroIsReplacedWithZero() {
-    byte[] doubleBytes = new byte[Double.BYTES];
-    byte[] floatBytes = new byte[Float.BYTES];
-    Platform.putDouble(doubleBytes, Platform.BYTE_ARRAY_OFFSET, -0.0d);
-    Platform.putFloat(floatBytes, Platform.BYTE_ARRAY_OFFSET, -0.0f);
-    double doubleFromPlatform = Platform.getDouble(doubleBytes, Platform.BYTE_ARRAY_OFFSET);
-    float floatFromPlatform = Platform.getFloat(floatBytes, Platform.BYTE_ARRAY_OFFSET);
-
-    Assert.assertEquals(Double.doubleToLongBits(0.0d), Double.doubleToLongBits(doubleFromPlatform));
-    Assert.assertEquals(Float.floatToIntBits(0.0f), Float.floatToIntBits(floatFromPlatform));
-  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/33460c58/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java
index 95263a0..7553ab8 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java
@@ -198,11 +198,46 @@ public abstract class UnsafeWriter {
     Platform.putLong(getBuffer(), offset, value);
   }
 
+  // We need to take care of NaN and -0.0 in several places:
+  //   1. When comparing values, different NaNs should be treated as the same, and `-0.0` and
+  //      `0.0` should be treated as the same.
+  //   2. In GROUP BY, different NaNs should belong to the same group, and -0.0 and 0.0 should
+  //      belong to the same group.
+  //   3. As join keys, different NaNs should be treated as the same, and `-0.0` and `0.0`
+  //      should be treated as the same.
+  //   4. As window partition keys, different NaNs should be treated as the same, and `-0.0`
+  //      and `0.0` should be treated as the same.
+  //
+  // Case 1 is fine, as we handle NaN and -0.0 well during comparison. For complex types, we
+  // recursively compare the fields/elements, so it's also fine.
+  //
+  // Cases 2, 3 and 4 are problematic, as they compare `UnsafeRow` binaries directly, and
+  // different NaNs have different binary representations; the same is true for -0.0 and 0.0.
+  //
+  // Here we normalize NaN and -0.0, so that `UnsafeProjection` will normalize them when writing
+  // float/double columns and nested fields to `UnsafeRow`.
+  //
+  // Note that we must do this for all `UnsafeProjection`s, not only the ones that extract
+  // join/grouping/window partition keys. `UnsafeProjection` copies unsafe data directly for
+  // complex types, so nested float/double values may not be normalized otherwise. We need to
+  // make sure that all the unsafe data (`UnsafeRow`, `UnsafeArrayData`, `UnsafeMapData`) has
+  // float/double normalized during creation.
   protected final void writeFloat(long offset, float value) {
+    if (Float.isNaN(value)) {
+      value = Float.NaN;
+    } else if (value == -0.0f) {
+      value = 0.0f;
+    }
     Platform.putFloat(getBuffer(), offset, value);
   }
 
+  // See comments for `writeFloat`.
   protected final void writeDouble(long offset, double value) {
+    if (Double.isNaN(value)) {
+      value = Double.NaN;
+    } else if (value == -0.0d) {
+      value = 0.0d;
+    }
     Platform.putDouble(getBuffer(), offset, value);
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/33460c58/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriterSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriterSuite.scala
index fb651b7..22e1fa6 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriterSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriterSuite.scala
@@ -50,4 +50,24 @@ class UnsafeRowWriterSuite extends SparkFunSuite {
     assert(res1 == res2)
   }
 
+  test("SPARK-26021: normalize float/double NaN and -0.0") {
+    val unsafeRowWriter1 = new UnsafeRowWriter(4)
+    unsafeRowWriter1.resetRowWriter()
+    unsafeRowWriter1.write(0, Float.NaN)
+    unsafeRowWriter1.write(1, Double.NaN)
+    unsafeRowWriter1.write(2, 0.0f)
+    unsafeRowWriter1.write(3, 0.0)
+    val res1 = unsafeRowWriter1.getRow
+
+    val unsafeRowWriter2 = new UnsafeRowWriter(4)
+    unsafeRowWriter2.resetRowWriter()
+    unsafeRowWriter2.write(0, 0.0f/0.0f)
+    unsafeRowWriter2.write(1, 0.0/0.0)
+    unsafeRowWriter2.write(2, -0.0f)
+    unsafeRowWriter2.write(3, -0.0)
+    val res2 = unsafeRowWriter2.getRow
+
+    // The two rows should be equal
+    assert(res1 == res2)
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/33460c58/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
index e6b30f9..c9f41ab 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
@@ -295,4 +295,16 @@ class DataFrameJoinSuite extends QueryTest with SharedSQLContext {
       df.join(df, df("id") <=> df("id")).queryExecution.optimizedPlan
     }
   }
+
+  test("NaN and -0.0 in join keys") {
+    val df1 = Seq(Float.NaN -> Double.NaN, 0.0f -> 0.0, -0.0f -> -0.0).toDF("f", "d")
+    val df2 = Seq(Float.NaN -> Double.NaN, 0.0f -> 0.0, -0.0f -> -0.0).toDF("f", "d")
+    val joined = df1.join(df2, Seq("f", "d"))
+    checkAnswer(joined, Seq(
+      Row(Float.NaN, Double.NaN),
+      Row(0.0f, 0.0),
+      Row(0.0f, 0.0),
+      Row(0.0f, 0.0),
+      Row(0.0f, 0.0)))
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/33460c58/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala
index 97a8439..bbeb1d1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala
@@ -658,4 +658,18 @@ class DataFrameWindowFunctionsSuite extends QueryTest with SharedSQLContext {
            |GROUP BY a
            |HAVING SUM(b) = 5 AND RANK() OVER(ORDER BY a) = 1""".stripMargin))
   }
+
+  test("NaN and -0.0 in window partition keys") {
+    val df = Seq(
+      (Float.NaN, Double.NaN, 1),
+      (0.0f/0.0f, 0.0/0.0, 1),
+      (0.0f, 0.0, 1),
+      (-0.0f, -0.0, 1)).toDF("f", "d", "i")
+    val result = df.select($"f", count("i").over(Window.partitionBy("f", "d")))
+    checkAnswer(result, Seq(
+      Row(Float.NaN, 2),
+      Row(Float.NaN, 2),
+      Row(0.0f, 2),
+      Row(0.0f, 2)))
+  }
 }

