Posted to commits@spark.apache.org by do...@apache.org on 2020/12/08 19:51:12 UTC

[spark] branch branch-3.0 updated: [SPARK-32110][SQL] normalize special floating numbers in HyperLogLog++

This is an automated email from the ASF dual-hosted git repository.

dongjoon pushed a commit to branch branch-3.0
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.0 by this push:
     new eae6a3e  [SPARK-32110][SQL] normalize special floating numbers in HyperLogLog++
eae6a3e is described below

commit eae6a3e9dc75912e2fbe80d86f05cce629de8022
Author: Wenchen Fan <we...@databricks.com>
AuthorDate: Tue Dec 8 11:41:35 2020 -0800

    [SPARK-32110][SQL] normalize special floating numbers in HyperLogLog++
    
    ### What changes were proposed in this pull request?
    
    Currently, Spark treats 0.0 and -0.0 as semantically equal, while still retaining the difference between them so that users can see -0.0 when a data set is displayed.
    
    The comparison expressions in Spark handle the special floating-point numbers and implement the correct semantics. However, Spark doesn't always use these comparison expressions to compare values, so the special floating-point numbers must be normalized before they are compared in these places:
    1. GROUP BY
    2. join keys
    3. window partition keys
    
    This PR fixes one more place that compares values without using comparison expressions: HyperLogLog++.
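    
    The root cause, in a minimal plain-Scala sketch (no Spark required): -0.0 equals 0.0 under IEEE 754 comparison, but the two have different bit patterns, so any code path that hashes or compares raw bits, as HyperLogLog++ does, sees them as distinct. NaN has the opposite problem: many bit patterns, none of which compare equal.
    
    ```scala
    // IEEE 754 comparison says the zeros are equal...
    val posZero = 0.0
    val negZero = -0.0
    assert(posZero == negZero)                       // true
    // ...but their raw bits differ, so a bitwise hash separates them.
    assert(java.lang.Double.doubleToRawLongBits(posZero) !=
      java.lang.Double.doubleToRawLongBits(negZero)) // 0x0L vs 0x8000000000000000L
    // NaN never compares equal, not even to itself.
    assert(Double.NaN != Double.NaN)
    ```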
    
    ### Why are the changes needed?
    
    Fix incorrect query results: HyperLogLog++ previously counted 0.0 and -0.0 (and distinct NaN bit patterns) as different values.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes, the result of HyperLogLog++ is now correct for inputs containing -0.0 or NaN.
    
    ### How was this patch tested?
    
    A new test case, plus a few more test cases that passed even before this PR, to improve test coverage.
    
    Closes #30673 from cloud-fan/bug.
    
    Authored-by: Wenchen Fan <we...@databricks.com>
    Signed-off-by: Dongjoon Hyun <do...@apache.org>
    (cherry picked from commit 6fd234503cf1e85715ccd3bda42f29dae1daa71b)
    Signed-off-by: Dongjoon Hyun <do...@apache.org>
---
 .../optimizer/NormalizeFloatingNumbers.scala       | 45 ++++++-----
 .../catalyst/util/HyperLogLogPlusPlusHelper.scala  |  8 +-
 .../sql/catalyst/expressions/PredicateSuite.scala  | 90 ++++++++++++++++++++++
 .../aggregate/HyperLogLogPlusPlusSuite.scala       | 24 +++++-
 4 files changed, 144 insertions(+), 23 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala
index f0cf671..59265221 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala
@@ -134,6 +134,28 @@ object NormalizeFloatingNumbers extends Rule[LogicalPlan] {
 
     case _ => throw new IllegalStateException(s"fail to normalize $expr")
   }
+
+  val FLOAT_NORMALIZER: Any => Any = (input: Any) => {
+    val f = input.asInstanceOf[Float]
+    if (f.isNaN) {
+      Float.NaN
+    } else if (f == -0.0f) {
+      0.0f
+    } else {
+      f
+    }
+  }
+
+  val DOUBLE_NORMALIZER: Any => Any = (input: Any) => {
+    val d = input.asInstanceOf[Double]
+    if (d.isNaN) {
+      Double.NaN
+    } else if (d == -0.0d) {
+      0.0d
+    } else {
+      d
+    }
+  }
 }
 
 case class NormalizeNaNAndZero(child: Expression) extends UnaryExpression with ExpectsInputTypes {
@@ -143,27 +165,8 @@ case class NormalizeNaNAndZero(child: Expression) extends UnaryExpression with E
   override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(FloatType, DoubleType))
 
   private lazy val normalizer: Any => Any = child.dataType match {
-    case FloatType => (input: Any) => {
-      val f = input.asInstanceOf[Float]
-      if (f.isNaN) {
-        Float.NaN
-      } else if (f == -0.0f) {
-        0.0f
-      } else {
-        f
-      }
-    }
-
-    case DoubleType => (input: Any) => {
-      val d = input.asInstanceOf[Double]
-      if (d.isNaN) {
-        Double.NaN
-      } else if (d == -0.0d) {
-        0.0d
-      } else {
-        d
-      }
-    }
+    case FloatType => NormalizeFloatingNumbers.FLOAT_NORMALIZER
+    case DoubleType => NormalizeFloatingNumbers.DOUBLE_NORMALIZER
   }
 
   override def nullSafeEval(input: Any): Any = {
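
The refactoring above hoists the two normalizers out of NormalizeNaNAndZero into vals on the NormalizeFloatingNumbers object, so the HyperLogLogPlusPlusHelper change below can reuse them instead of duplicating the logic. A quick sketch of their effect (the names come straight from the diff; results shown REPL-style):

```scala
import org.apache.spark.sql.catalyst.optimizer.NormalizeFloatingNumbers.{DOUBLE_NORMALIZER, FLOAT_NORMALIZER}

DOUBLE_NORMALIZER(-0.0d)       // 0.0: negative zero is canonicalized
DOUBLE_NORMALIZER(Double.NaN)  // the one canonical NaN bit pattern
FLOAT_NORMALIZER(-0.0f)        // 0.0f
DOUBLE_NORMALIZER(1.5d)        // 1.5: ordinary values pass through unchanged
```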
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/HyperLogLogPlusPlusHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/HyperLogLogPlusPlusHelper.scala
index ea619c6..6471a74 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/HyperLogLogPlusPlusHelper.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/HyperLogLogPlusPlusHelper.scala
@@ -22,6 +22,7 @@ import java.util
 
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.XxHash64Function
+import org.apache.spark.sql.catalyst.optimizer.NormalizeFloatingNumbers.{DOUBLE_NORMALIZER, FLOAT_NORMALIZER}
 import org.apache.spark.sql.types._
 
 // A helper class for HyperLogLogPlusPlus.
@@ -88,7 +89,12 @@ class HyperLogLogPlusPlusHelper(relativeSD: Double) extends Serializable {
    *
    * Variable names in the HLL++ paper match variable names in the code.
    */
-  def update(buffer: InternalRow, bufferOffset: Int, value: Any, dataType: DataType): Unit = {
+  def update(buffer: InternalRow, bufferOffset: Int, _value: Any, dataType: DataType): Unit = {
+    val value = dataType match {
+      case FloatType => FLOAT_NORMALIZER.apply(_value)
+      case DoubleType => DOUBLE_NORMALIZER.apply(_value)
+      case _ => _value
+    }
     // Create the hashed value 'x'.
     val x = XxHash64Function.hash(value, dataType, 42L)
 
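With normalization in place, update() hashes the canonical form of the value, so both zeros and all NaN bit patterns collapse to one hashed value. An illustrative end-to-end check via the DataFrame API (a sketch assuming an active SparkSession named spark; per this PR's description, the approximate distinct count here was 2 before the fix):

```scala
import org.apache.spark.sql.functions.approx_count_distinct
import spark.implicits._

// HyperLogLog++ backs approx_count_distinct; 0.0 and -0.0 now count as one value.
Seq(0.0, -0.0).toDF("d")
  .agg(approx_count_distinct("d"))
  .show()  // expected: 1
```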
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala
index a36baec..6f75623 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala
@@ -554,4 +554,94 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
     checkEvaluation(GreaterThan(Literal(Float.NaN), Literal(Float.NaN)), false)
     checkEvaluation(GreaterThan(Literal(0.0F), Literal(-0.0F)), false)
   }
+
+  test("SPARK-32110: compare special double/float values in array") {
+    def createUnsafeDoubleArray(d: Double): Literal = {
+      Literal(UnsafeArrayData.fromPrimitiveArray(Array(d)), ArrayType(DoubleType))
+    }
+    def createSafeDoubleArray(d: Double): Literal = {
+      Literal(new GenericArrayData(Array(d)), ArrayType(DoubleType))
+    }
+    def createUnsafeFloatArray(d: Double): Literal = {
+      Literal(UnsafeArrayData.fromPrimitiveArray(Array(d.toFloat)), ArrayType(FloatType))
+    }
+    def createSafeFloatArray(d: Double): Literal = {
+      Literal(new GenericArrayData(Array(d.toFloat)), ArrayType(FloatType))
+    }
+    def checkExpr(
+        exprBuilder: (Expression, Expression) => Expression,
+        left: Double,
+        right: Double,
+        expected: Any): Unit = {
+      // test double
+      checkEvaluation(
+        exprBuilder(createUnsafeDoubleArray(left), createUnsafeDoubleArray(right)), expected)
+      checkEvaluation(
+        exprBuilder(createUnsafeDoubleArray(left), createSafeDoubleArray(right)), expected)
+      checkEvaluation(
+        exprBuilder(createSafeDoubleArray(left), createSafeDoubleArray(right)), expected)
+      // test float
+      checkEvaluation(
+        exprBuilder(createUnsafeFloatArray(left), createUnsafeFloatArray(right)), expected)
+      checkEvaluation(
+        exprBuilder(createUnsafeFloatArray(left), createSafeFloatArray(right)), expected)
+      checkEvaluation(
+        exprBuilder(createSafeFloatArray(left), createSafeFloatArray(right)), expected)
+    }
+
+    checkExpr(EqualTo, Double.NaN, Double.NaN, true)
+    checkExpr(EqualTo, Double.NaN, Double.PositiveInfinity, false)
+    checkExpr(EqualTo, 0.0, -0.0, true)
+    checkExpr(GreaterThan, Double.NaN, Double.PositiveInfinity, true)
+    checkExpr(GreaterThan, Double.NaN, Double.NaN, false)
+    checkExpr(GreaterThan, 0.0, -0.0, false)
+  }
+
+  test("SPARK-32110: compare special double/float values in struct") {
+    def createUnsafeDoubleRow(d: Double): Literal = {
+      val dt = new StructType().add("d", "double")
+      val converter = UnsafeProjection.create(dt)
+      val unsafeRow = converter.apply(InternalRow(d))
+      Literal(unsafeRow, dt)
+    }
+    def createSafeDoubleRow(d: Double): Literal = {
+      Literal(InternalRow(d), new StructType().add("d", "double"))
+    }
+    def createUnsafeFloatRow(d: Double): Literal = {
+      val dt = new StructType().add("f", "float")
+      val converter = UnsafeProjection.create(dt)
+      val unsafeRow = converter.apply(InternalRow(d.toFloat))
+      Literal(unsafeRow, dt)
+    }
+    def createSafeFloatRow(d: Double): Literal = {
+      Literal(InternalRow(d.toFloat), new StructType().add("f", "float"))
+    }
+    def checkExpr(
+        exprBuilder: (Expression, Expression) => Expression,
+        left: Double,
+        right: Double,
+        expected: Any): Unit = {
+      // test double
+      checkEvaluation(
+        exprBuilder(createUnsafeDoubleRow(left), createUnsafeDoubleRow(right)), expected)
+      checkEvaluation(
+        exprBuilder(createUnsafeDoubleRow(left), createSafeDoubleRow(right)), expected)
+      checkEvaluation(
+        exprBuilder(createSafeDoubleRow(left), createSafeDoubleRow(right)), expected)
+      // test float
+      checkEvaluation(
+        exprBuilder(createUnsafeFloatRow(left), createUnsafeFloatRow(right)), expected)
+      checkEvaluation(
+        exprBuilder(createUnsafeFloatRow(left), createSafeFloatRow(right)), expected)
+      checkEvaluation(
+        exprBuilder(createSafeFloatRow(left), createSafeFloatRow(right)), expected)
+    }
+
+    checkExpr(EqualTo, Double.NaN, Double.NaN, true)
+    checkExpr(EqualTo, Double.NaN, Double.PositiveInfinity, false)
+    checkExpr(EqualTo, 0.0, -0.0, true)
+    checkExpr(GreaterThan, Double.NaN, Double.PositiveInfinity, true)
+    checkExpr(GreaterThan, Double.NaN, Double.NaN, false)
+    checkExpr(GreaterThan, 0.0, -0.0, false)
+  }
 }
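
The tests above deliberately mix "unsafe" and "safe" representations: UnsafeArrayData/UnsafeRow store values in a flat binary layout, while GenericArrayData/InternalRow hold boxed objects, and the two take different comparison code paths. A minimal sketch of constructing both array forms (mirroring the helpers in the diff):

```scala
import org.apache.spark.sql.catalyst.expressions.UnsafeArrayData
import org.apache.spark.sql.catalyst.util.GenericArrayData

// Flat binary layout: each double's raw 8 bytes are stored inline.
val unsafeArray = UnsafeArrayData.fromPrimitiveArray(Array(-0.0))
// Object layout: elements are held as boxed values.
val safeArray = new GenericArrayData(Array(-0.0))
```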
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlusSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlusSuite.scala
index 98fd04c..1afccea 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlusSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlusSuite.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark.sql.catalyst.expressions.aggregate
 
+import java.lang.{Double => JDouble}
 import java.util.Random
 
 import scala.collection.mutable
@@ -24,7 +25,7 @@ import scala.collection.mutable
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{BoundReference, SpecificInternalRow}
-import org.apache.spark.sql.types.{DataType, IntegerType}
+import org.apache.spark.sql.types.{DataType, DoubleType, IntegerType}
 
 class HyperLogLogPlusPlusSuite extends SparkFunSuite {
 
@@ -153,4 +154,25 @@ class HyperLogLogPlusPlusSuite extends SparkFunSuite {
     // Check if the buffers are equal.
     assert(buffer2 == buffer1a, "Buffers should be equal")
   }
+
+  test("SPARK-32110: add 0.0 and -0.0") {
+    val (hll, input, buffer) = createEstimator(0.05, DoubleType)
+    input.setDouble(0, 0.0)
+    hll.update(buffer, input)
+    input.setDouble(0, -0.0)
+    hll.update(buffer, input)
+    evaluateEstimate(hll, buffer, 1);
+  }
+
+  test("SPARK-32110: add NaN") {
+    val (hll, input, buffer) = createEstimator(0.05, DoubleType)
+    input.setDouble(0, Double.NaN)
+    hll.update(buffer, input)
+    val specialNaN = JDouble.longBitsToDouble(0x7ff1234512345678L)
+    assert(JDouble.isNaN(specialNaN))
+    assert(JDouble.doubleToRawLongBits(Double.NaN) != JDouble.doubleToRawLongBits(specialNaN))
+    input.setDouble(0, specialNaN)
+    hll.update(buffer, input)
+    evaluateEstimate(hll, buffer, 1);
+  }
 }

