Posted to commits@spark.apache.org by ya...@apache.org on 2023/06/07 06:31:02 UTC

[spark] branch master updated: [SPARK-43717][CONNECT] Scala client reduce agg cannot handle null partitions for scala primitive inputs

This is an automated email from the ASF dual-hosted git repository.

yangjie01 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 5021638ee14 [SPARK-43717][CONNECT] Scala client reduce agg cannot handle null partitions for scala primitive inputs
5021638ee14 is described below

commit 5021638ee14758b92309942a1bcaed2b6554f810
Author: Zhen Li <zh...@users.noreply.github.com>
AuthorDate: Wed Jun 7 14:30:42 2023 +0800

    [SPARK-43717][CONNECT] Scala client reduce agg cannot handle null partitions for scala primitive inputs
    
    ### What changes were proposed in this pull request?
    The Scala client fails with an NPE when running the following reduce aggregation:
    ```
    spark.range(0, 5, 1, 10).as[Long].reduce(_ + _) == 10
    ```
    The reason is that `range` produces null partitions, and the Reduce encoder cannot set the default value correctly for partitions that contain Scala primitives. In the example, we expect 0 but receive null. This causes the codegen to wrongly assume the input is nullable and to generate incorrect code.
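    
    A minimal reproduction sketch, assuming a session bound to `spark` (e.g. a Connect shell) so that `spark.implicits._` is importable:
    ```
    // 5 rows spread over 10 partitions, so half the partitions are empty
    // and the reduce buffer starts from the zero value.
    import spark.implicits._
    val sum = spark.range(0, 5, 1, 10).as[Long].reduce(_ + _)
    // Expected: 0 + 1 + 2 + 3 + 4 == 10; before this fix the client threw an NPE.
    assert(sum == 10)
    ```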
    
    ### Why are the changes needed?
    Bug fix
    
    ### Does this PR introduce _any_ user-facing change?
    No.
    
    ### How was this patch tested?
    Unit and Scala Client E2E tests.
    
    Closes #41264 from zhenlineo/fix-agg-null.
    
    Authored-by: Zhen Li <zh...@users.noreply.github.com>
    Signed-off-by: yangjie01 <ya...@baidu.com>
---
 .../spark/sql/UserDefinedFunctionE2ETestSuite.scala  | 20 ++++++++++++++++----
 .../spark/sql/expressions/ReduceAggregator.scala     | 13 ++++++++++++-
 .../sql/expressions/ReduceAggregatorSuite.scala      | 10 ++++++++--
 3 files changed, 36 insertions(+), 7 deletions(-)
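
The core of the fix, visible in the ReduceAggregator diff below, is to derive a type-appropriate zero from the encoder's runtime class instead of unconditionally using null. A standalone sketch of the same idea, using a plain `ClassTag` in place of Spark's `Encoder` (an assumption made here for brevity):
```
import scala.reflect.{classTag, ClassTag}

// Mirrors the patch: primitive runtime classes get their JVM default,
// while boxed/reference types keep null as the zero value.
def primitiveZero[T: ClassTag]: Any = classTag[T].runtimeClass match {
  case java.lang.Boolean.TYPE => false
  case java.lang.Byte.TYPE    => 0.toByte
  case java.lang.Short.TYPE   => 0.toShort
  case java.lang.Integer.TYPE => 0
  case java.lang.Long.TYPE    => 0L
  case java.lang.Float.TYPE   => 0f
  case java.lang.Double.TYPE  => 0d
  case _                      => null
}

// e.g. primitiveZero[Long] == 0L, while primitiveZero[java.lang.Integer] == null,
// matching the "zero value boxed null" test case added below.
```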

diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionE2ETestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionE2ETestSuite.scala
index b5bbee67803..ca1bcf3fe67 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionE2ETestSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionE2ETestSuite.scala
@@ -198,18 +198,30 @@ class UserDefinedFunctionE2ETestSuite extends RemoteSparkSession {
     assert(sum.get() == 0) // The value is not 45
   }
 
-  test("Dataset reduce") {
+  test("Dataset reduce without null partition inputs") {
     val session: SparkSession = spark
     import session.implicits._
-    assert(spark.range(10).map(_ + 1).reduce(_ + _) == 55)
+    assert(spark.range(0, 10, 1, 5).map(_ + 1).reduce(_ + _) == 55)
   }
 
-  test("Dataset reduce - java") {
+  test("Dataset reduce with null partition inputs") {
+    val session: SparkSession = spark
+    import session.implicits._
+    assert(spark.range(0, 10, 1, 16).map(_ + 1).reduce(_ + _) == 55)
+  }
+
+  test("Dataset reduce with null partition inputs - java to scala long type") {
+    val session: SparkSession = spark
+    import session.implicits._
+    assert(spark.range(0, 5, 1, 10).as[Long].reduce(_ + _) == 10)
+  }
+
+  test("Dataset reduce with null partition inputs - java") {
     val session: SparkSession = spark
     import session.implicits._
     assert(
       spark
-        .range(10)
+        .range(0, 10, 1, 16)
         .map(_ + 1)
         .reduce(new ReduceFunction[Long] {
           override def call(v1: Long, v2: Long): Long = v1 + v2
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/ReduceAggregator.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/ReduceAggregator.scala
index 41306cd0a99..e897fdfe008 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/ReduceAggregator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/ReduceAggregator.scala
@@ -32,7 +32,18 @@ private[sql] class ReduceAggregator[T: Encoder](func: (T, T) => T)
 
   @transient private val encoder = implicitly[Encoder[T]]
 
-  override def zero: (Boolean, T) = (false, null.asInstanceOf[T])
+  private val _zero = encoder.clsTag.runtimeClass match {
+    case java.lang.Boolean.TYPE => false
+    case java.lang.Byte.TYPE => 0.toByte
+    case java.lang.Short.TYPE => 0.toShort
+    case java.lang.Integer.TYPE => 0
+    case java.lang.Long.TYPE => 0L
+    case java.lang.Float.TYPE => 0f
+    case java.lang.Double.TYPE => 0d
+    case _ => null
+  }
+
+  override def zero: (Boolean, T) = (false, _zero.asInstanceOf[T])
 
   override def bufferEncoder: Encoder[(Boolean, T)] =
     ExpressionEncoder.tuple(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/expressions/ReduceAggregatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/expressions/ReduceAggregatorSuite.scala
index f65dcdf119c..c1071373287 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/expressions/ReduceAggregatorSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/expressions/ReduceAggregatorSuite.scala
@@ -24,10 +24,16 @@ import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 class ReduceAggregatorSuite extends SparkFunSuite {
 
   test("zero value") {
-    val encoder: ExpressionEncoder[Int] = ExpressionEncoder()
     val func = (v1: Int, v2: Int) => v1 + v2
     val aggregator: ReduceAggregator[Int] = new ReduceAggregator(func)(Encoders.scalaInt)
-    assert(aggregator.zero == (false, null).asInstanceOf[(Boolean, Int)])
+    assert(aggregator.zero == (false, 0))
+  }
+
+  test("zero value boxed null") {
+    val func = (v1: java.lang.Integer, v2: java.lang.Integer) =>
+      (v1 + v2).asInstanceOf[java.lang.Integer]
+    val aggregator: ReduceAggregator[java.lang.Integer] = new ReduceAggregator(func)(Encoders.INT)
+    assert(aggregator.zero == (false, null).asInstanceOf[(Boolean, java.lang.Integer)])
   }
 
   test("reduce, merge and finish") {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org