You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by ya...@apache.org on 2023/06/07 06:31:02 UTC
[spark] branch master updated: [SPARK-43717][CONNECT] Scala client reduce agg cannot handle null partitions for scala primitive inputs
This is an automated email from the ASF dual-hosted git repository.
yangjie01 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 5021638ee14 [SPARK-43717][CONNECT] Scala client reduce agg cannot handle null partitions for scala primitive inputs
5021638ee14 is described below
commit 5021638ee14758b92309942a1bcaed2b6554f810
Author: Zhen Li <zh...@users.noreply.github.com>
AuthorDate: Wed Jun 7 14:30:42 2023 +0800
[SPARK-43717][CONNECT] Scala client reduce agg cannot handle null partitions for scala primitive inputs
### What changes were proposed in this pull request?
The Scala client fails with a NullPointerException (NPE) when running the following reduce aggregation:
```
spark.range(0, 5, 1, 10).as[Long].reduce(_ + _) == 10
```
The reason is that `range` can produce null (empty) partitions. For example, `range(0, 5, 1, 10)` spreads 5 rows over 10 partitions, so some partitions receive no rows. For such partitions, the Reduce encoder cannot set the default value correctly when the input is a Scala primitive. In the example, we expect 0 but receive null. This causes the codegen to wrongly assume that the input is nullable and to generate incorrect code.
### Why are the changes needed?
Bug fix
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Unit and Scala Client E2E tests.
Closes #41264 from zhenlineo/fix-agg-null.
Authored-by: Zhen Li <zh...@users.noreply.github.com>
Signed-off-by: yangjie01 <ya...@baidu.com>
---
.../spark/sql/UserDefinedFunctionE2ETestSuite.scala | 20 ++++++++++++++++----
.../spark/sql/expressions/ReduceAggregator.scala | 13 ++++++++++++-
.../sql/expressions/ReduceAggregatorSuite.scala | 10 ++++++++--
3 files changed, 36 insertions(+), 7 deletions(-)
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionE2ETestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionE2ETestSuite.scala
index b5bbee67803..ca1bcf3fe67 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionE2ETestSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionE2ETestSuite.scala
@@ -198,18 +198,30 @@ class UserDefinedFunctionE2ETestSuite extends RemoteSparkSession {
assert(sum.get() == 0) // The value is not 45
}
- test("Dataset reduce") {
+ test("Dataset reduce without null partition inputs") {
val session: SparkSession = spark
import session.implicits._
- assert(spark.range(10).map(_ + 1).reduce(_ + _) == 55)
+ assert(spark.range(0, 10, 1, 5).map(_ + 1).reduce(_ + _) == 55)
}
- test("Dataset reduce - java") {
+ test("Dataset reduce with null partition inputs") {
+ val session: SparkSession = spark
+ import session.implicits._
+ assert(spark.range(0, 10, 1, 16).map(_ + 1).reduce(_ + _) == 55)
+ }
+
+ test("Dataset reduce with null partition inputs - java to scala long type") {
+ val session: SparkSession = spark
+ import session.implicits._
+ assert(spark.range(0, 5, 1, 10).as[Long].reduce(_ + _) == 10)
+ }
+
+ test("Dataset reduce with null partition inputs - java") {
val session: SparkSession = spark
import session.implicits._
assert(
spark
- .range(10)
+ .range(0, 10, 1, 16)
.map(_ + 1)
.reduce(new ReduceFunction[Long] {
override def call(v1: Long, v2: Long): Long = v1 + v2
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/ReduceAggregator.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/ReduceAggregator.scala
index 41306cd0a99..e897fdfe008 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/ReduceAggregator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/ReduceAggregator.scala
@@ -32,7 +32,18 @@ private[sql] class ReduceAggregator[T: Encoder](func: (T, T) => T)
@transient private val encoder = implicitly[Encoder[T]]
- override def zero: (Boolean, T) = (false, null.asInstanceOf[T])
+ private val _zero = encoder.clsTag.runtimeClass match {
+ case java.lang.Boolean.TYPE => false
+ case java.lang.Byte.TYPE => 0.toByte
+ case java.lang.Short.TYPE => 0.toShort
+ case java.lang.Integer.TYPE => 0
+ case java.lang.Long.TYPE => 0L
+ case java.lang.Float.TYPE => 0f
+ case java.lang.Double.TYPE => 0d
+ case _ => null
+ }
+
+ override def zero: (Boolean, T) = (false, _zero.asInstanceOf[T])
override def bufferEncoder: Encoder[(Boolean, T)] =
ExpressionEncoder.tuple(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/expressions/ReduceAggregatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/expressions/ReduceAggregatorSuite.scala
index f65dcdf119c..c1071373287 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/expressions/ReduceAggregatorSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/expressions/ReduceAggregatorSuite.scala
@@ -24,10 +24,16 @@ import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
class ReduceAggregatorSuite extends SparkFunSuite {
test("zero value") {
- val encoder: ExpressionEncoder[Int] = ExpressionEncoder()
val func = (v1: Int, v2: Int) => v1 + v2
val aggregator: ReduceAggregator[Int] = new ReduceAggregator(func)(Encoders.scalaInt)
- assert(aggregator.zero == (false, null).asInstanceOf[(Boolean, Int)])
+ assert(aggregator.zero == (false, 0))
+ }
+
+ test("zero value boxed null") {
+ val func = (v1: java.lang.Integer, v2: java.lang.Integer) =>
+ (v1 + v2).asInstanceOf[java.lang.Integer]
+ val aggregator: ReduceAggregator[java.lang.Integer] = new ReduceAggregator(func)(Encoders.INT)
+ assert(aggregator.zero == (false, null).asInstanceOf[(Boolean, java.lang.Integer)])
}
test("reduce, merge and finish") {
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org