You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by ya...@apache.org on 2019/02/13 05:07:41 UTC
[spark] branch master updated: [SPARK-26798][SQL]
HandleNullInputsForUDF should trust nullability
This is an automated email from the ASF dual-hosted git repository.
yamamuro pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new f502e20 [SPARK-26798][SQL] HandleNullInputsForUDF should trust nullability
f502e20 is described below
commit f502e209f49d3d76f947d1a8ba38c8c8a86e0bef
Author: Wenchen Fan <we...@databricks.com>
AuthorDate: Wed Feb 13 14:07:03 2019 +0900
[SPARK-26798][SQL] HandleNullInputsForUDF should trust nullability
## What changes were proposed in this pull request?
There is a very old TODO in `HandleNullInputsForUDF` saying that we can skip the null check if the input is not nullable. We already rely on the nullability info in many places, so we can trust it here too.
## How was this patch tested?
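Re-enabled a previously ignored test case in `AnalysisSuite`, and added Java UDF cases to the null-handling tests in `UDFSuite`.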
Closes #23712 from cloud-fan/minor.
Lead-authored-by: Wenchen Fan <we...@databricks.com>
Co-authored-by: Xiao Li <ga...@gmail.com>
Signed-off-by: Takeshi Yamamuro <ya...@apache.org>
---
.../spark/sql/catalyst/analysis/Analyzer.scala | 47 ++++++++++++---------
.../spark/sql/catalyst/expressions/ScalaUDF.scala | 9 ++--
.../sql/catalyst/analysis/AnalysisSuite.scala | 36 +++++++++-------
.../sql/catalyst/expressions/ScalaUDFSuite.scala | 6 +--
.../spark/sql/catalyst/trees/TreeNodeSuite.scala | 2 +-
.../org/apache/spark/sql/UDFRegistration.scala | 48 +++++++++++-----------
.../datasources/FileFormatDataWriter.scala | 2 +-
.../sql/expressions/UserDefinedFunction.scala | 8 ++--
.../test/scala/org/apache/spark/sql/UDFSuite.scala | 37 ++++++++++++-----
9 files changed, 114 insertions(+), 81 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index a84bb76..793c337 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -2147,27 +2147,36 @@ class Analyzer(
case p => p transformExpressionsUp {
- case udf @ ScalaUDF(_, _, inputs, inputsNullSafe, _, _, _, _)
- if inputsNullSafe.contains(false) =>
+ case udf @ ScalaUDF(_, _, inputs, inputPrimitives, _, _, _, _)
+ if inputPrimitives.contains(true) =>
// Otherwise, add special handling of null for fields that can't accept null.
// The result of operations like this, when passed null, is generally to return null.
- assert(inputsNullSafe.length == inputs.length)
-
- // TODO: skip null handling for not-nullable primitive inputs after we can completely
- // trust the `nullable` information.
- val inputsNullCheck = inputsNullSafe.zip(inputs)
- .filter { case (nullSafe, _) => !nullSafe }
- .map { case (_, expr) => IsNull(expr) }
- .reduceLeftOption[Expression]((e1, e2) => Or(e1, e2))
- // Once we add an `If` check above the udf, it is safe to mark those checked inputs
- // as null-safe (i.e., set `inputsNullSafe` all `true`), because the null-returning
- // branch of `If` will be called if any of these checked inputs is null. Thus we can
- // prevent this rule from being applied repeatedly.
- val newInputsNullSafe = inputsNullSafe.map(_ => true)
- inputsNullCheck
- .map(If(_, Literal.create(null, udf.dataType),
- udf.copy(inputsNullSafe = newInputsNullSafe)))
- .getOrElse(udf)
+ assert(inputPrimitives.length == inputs.length)
+
+ val inputPrimitivesPair = inputPrimitives.zip(inputs)
+ val inputNullCheck = inputPrimitivesPair.collect {
+ case (isPrimitive, input) if isPrimitive && input.nullable =>
+ IsNull(input)
+ }.reduceLeftOption[Expression](Or)
+
+ if (inputNullCheck.isDefined) {
+ // Once we add an `If` check above the udf, it is safe to mark those checked inputs
+ // as null-safe (i.e., wrap with `KnownNotNull`), because the null-returning
+ // branch of `If` will be called if any of these checked inputs is null. Thus we can
+ // prevent this rule from being applied repeatedly.
+ val newInputs = inputPrimitivesPair.map {
+ case (isPrimitive, input) =>
+ if (isPrimitive && input.nullable) {
+ KnownNotNull(input)
+ } else {
+ input
+ }
+ }
+ val newUDF = udf.copy(children = newInputs)
+ If(inputNullCheck.get, Literal.create(null, udf.dataType), newUDF)
+ } else {
+ udf
+ }
}
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
index c9e0a2e..eb45f08 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
@@ -31,9 +31,10 @@ import org.apache.spark.sql.types.{AbstractDataType, DataType}
* null. Use boxed type or [[Option]] if you wanna do the null-handling yourself.
* @param dataType Return type of function.
* @param children The input expressions of this UDF.
- * @param inputsNullSafe Whether the inputs are of non-primitive types or not nullable. Null values
- * of Scala primitive types will be converted to the type's default value and
- * lead to wrong results, thus need special handling before calling the UDF.
+ * @param inputPrimitives The analyzer should be aware of Scala primitive types so as to make the
+ * UDF return null if there is any null input value of these types. On the
+ * other hand, Java UDFs can only have boxed types, thus this parameter will
+ * always be all false.
* @param inputTypes The expected input types of this UDF, used to perform type coercion. If we do
* not want to perform coercion, simply use "Nil". Note that it would've been
* better to use Option of Seq[DataType] so we can use "None" as the case for no
@@ -47,7 +48,7 @@ case class ScalaUDF(
function: AnyRef,
dataType: DataType,
children: Seq[Expression],
- inputsNullSafe: Seq[Boolean],
+ inputPrimitives: Seq[Boolean],
inputTypes: Seq[AbstractDataType] = Nil,
udfName: Option[String] = None,
nullable: Boolean = true,
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
index 9829484..8038733 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
@@ -303,51 +303,57 @@ class AnalysisSuite extends AnalysisTest with Matchers {
}
test("SPARK-11725: correctly handle null inputs for ScalaUDF") {
- val string = testRelation2.output(0)
- val double = testRelation2.output(2)
- val short = testRelation2.output(4)
+ val testRelation = LocalRelation(
+ AttributeReference("a", StringType)(),
+ AttributeReference("b", DoubleType)(),
+ AttributeReference("c", ShortType)(),
+ AttributeReference("d", DoubleType, nullable = false)())
+
+ val string = testRelation.output(0)
+ val double = testRelation.output(1)
+ val short = testRelation.output(2)
+ val nonNullableDouble = testRelation.output(3)
val nullResult = Literal.create(null, StringType)
def checkUDF(udf: Expression, transformed: Expression): Unit = {
checkAnalysis(
- Project(Alias(udf, "")() :: Nil, testRelation2),
- Project(Alias(transformed, "")() :: Nil, testRelation2)
+ Project(Alias(udf, "")() :: Nil, testRelation),
+ Project(Alias(transformed, "")() :: Nil, testRelation)
)
}
// non-primitive parameters do not need special null handling
- val udf1 = ScalaUDF((s: String) => "x", StringType, string :: Nil, true :: Nil)
+ val udf1 = ScalaUDF((s: String) => "x", StringType, string :: Nil, false :: Nil)
val expected1 = udf1
checkUDF(udf1, expected1)
// only primitive parameter needs special null handling
val udf2 = ScalaUDF((s: String, d: Double) => "x", StringType, string :: double :: Nil,
- true :: false :: Nil)
+ false :: true :: Nil)
val expected2 =
- If(IsNull(double), nullResult, udf2.copy(inputsNullSafe = true :: true :: Nil))
+ If(IsNull(double), nullResult, udf2.copy(children = string :: KnownNotNull(double) :: Nil))
checkUDF(udf2, expected2)
// special null handling should apply to all primitive parameters
val udf3 = ScalaUDF((s: Short, d: Double) => "x", StringType, short :: double :: Nil,
- false :: false :: Nil)
+ true :: true :: Nil)
val expected3 = If(
IsNull(short) || IsNull(double),
nullResult,
- udf3.copy(inputsNullSafe = true :: true :: Nil))
+ udf3.copy(children = KnownNotNull(short) :: KnownNotNull(double) :: Nil))
checkUDF(udf3, expected3)
// we can skip special null handling for primitive parameters that are not nullable
- // TODO: this is disabled for now as we can not completely trust `nullable`.
val udf4 = ScalaUDF(
(s: Short, d: Double) => "x",
StringType,
- short :: double.withNullability(false) :: Nil,
- false :: false :: Nil)
+ short :: nonNullableDouble :: Nil,
+ true :: true :: Nil)
val expected4 = If(
IsNull(short),
nullResult,
- udf4.copy(inputsNullSafe = true :: true :: Nil))
- // checkUDF(udf4, expected4)
+ udf4.copy(children = KnownNotNull(short) :: nonNullableDouble :: Nil))
+ checkUDF(udf4, expected4)
}
test("SPARK-24891 Fix HandleNullInputsForUDF rule") {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala
index 467cfd5..df92fa3 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala
@@ -29,7 +29,7 @@ class ScalaUDFSuite extends SparkFunSuite with ExpressionEvalHelper {
val intUdf = ScalaUDF((i: Int) => i + 1, IntegerType, Literal(1) :: Nil, true :: Nil)
checkEvaluation(intUdf, 2)
- val stringUdf = ScalaUDF((s: String) => s + "x", StringType, Literal("a") :: Nil, true :: Nil)
+ val stringUdf = ScalaUDF((s: String) => s + "x", StringType, Literal("a") :: Nil, false :: Nil)
checkEvaluation(stringUdf, "ax")
}
@@ -38,7 +38,7 @@ class ScalaUDFSuite extends SparkFunSuite with ExpressionEvalHelper {
(s: String) => s.toLowerCase(Locale.ROOT),
StringType,
Literal.create(null, StringType) :: Nil,
- true :: Nil)
+ false :: Nil)
val e1 = intercept[SparkException](udf.eval())
assert(e1.getMessage.contains("Failed to execute user defined function"))
@@ -51,7 +51,7 @@ class ScalaUDFSuite extends SparkFunSuite with ExpressionEvalHelper {
test("SPARK-22695: ScalaUDF should not use global variables") {
val ctx = new CodegenContext
- ScalaUDF((s: String) => s + "x", StringType, Literal("a") :: Nil, true :: Nil).genCode(ctx)
+ ScalaUDF((s: String) => s + "x", StringType, Literal("a") :: Nil, false :: Nil).genCode(ctx)
assert(ctx.inlinedMutableStates.isEmpty)
}
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
index 64aa1ee..cb911d7 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
@@ -564,7 +564,7 @@ class TreeNodeSuite extends SparkFunSuite {
}
test("toJSON should not throws java.lang.StackOverflowError") {
- val udf = ScalaUDF(SelfReferenceUDF(), BooleanType, Seq("col1".attr), true :: Nil)
+ val udf = ScalaUDF(SelfReferenceUDF(), BooleanType, Seq("col1".attr), false :: Nil)
// Should not throw java.lang.StackOverflowError
udf.toJSON
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
index fe5d1af..83425de 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
@@ -151,7 +151,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
|def register(name: String, f: UDF$i[$extTypeArgs], returnType: DataType): Unit = {
| val func = f$anyCast.call($anyParams)
| def builder(e: Seq[Expression]) = if (e.length == $i) {
- | ScalaUDF($funcCall, returnType, e, e.map(_ => true), udfName = Some(name))
+ | ScalaUDF($funcCall, returnType, e, e.map(_ => false), udfName = Some(name))
| } else {
| throw new AnalysisException("Invalid number of arguments for function " + name +
| ". Expected: $i; Found: " + e.length)
@@ -719,7 +719,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
def register(name: String, f: UDF0[_], returnType: DataType): Unit = {
val func = f.asInstanceOf[UDF0[Any]].call()
def builder(e: Seq[Expression]) = if (e.length == 0) {
- ScalaUDF(() => func, returnType, e, e.map(_ => true), udfName = Some(name))
+ ScalaUDF(() => func, returnType, e, e.map(_ => false), udfName = Some(name))
} else {
throw new AnalysisException("Invalid number of arguments for function " + name +
". Expected: 0; Found: " + e.length)
@@ -734,7 +734,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
def register(name: String, f: UDF1[_, _], returnType: DataType): Unit = {
val func = f.asInstanceOf[UDF1[Any, Any]].call(_: Any)
def builder(e: Seq[Expression]) = if (e.length == 1) {
- ScalaUDF(func, returnType, e, e.map(_ => true), udfName = Some(name))
+ ScalaUDF(func, returnType, e, e.map(_ => false), udfName = Some(name))
} else {
throw new AnalysisException("Invalid number of arguments for function " + name +
". Expected: 1; Found: " + e.length)
@@ -749,7 +749,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
def register(name: String, f: UDF2[_, _, _], returnType: DataType): Unit = {
val func = f.asInstanceOf[UDF2[Any, Any, Any]].call(_: Any, _: Any)
def builder(e: Seq[Expression]) = if (e.length == 2) {
- ScalaUDF(func, returnType, e, e.map(_ => true), udfName = Some(name))
+ ScalaUDF(func, returnType, e, e.map(_ => false), udfName = Some(name))
} else {
throw new AnalysisException("Invalid number of arguments for function " + name +
". Expected: 2; Found: " + e.length)
@@ -764,7 +764,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
def register(name: String, f: UDF3[_, _, _, _], returnType: DataType): Unit = {
val func = f.asInstanceOf[UDF3[Any, Any, Any, Any]].call(_: Any, _: Any, _: Any)
def builder(e: Seq[Expression]) = if (e.length == 3) {
- ScalaUDF(func, returnType, e, e.map(_ => true), udfName = Some(name))
+ ScalaUDF(func, returnType, e, e.map(_ => false), udfName = Some(name))
} else {
throw new AnalysisException("Invalid number of arguments for function " + name +
". Expected: 3; Found: " + e.length)
@@ -779,7 +779,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
def register(name: String, f: UDF4[_, _, _, _, _], returnType: DataType): Unit = {
val func = f.asInstanceOf[UDF4[Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any)
def builder(e: Seq[Expression]) = if (e.length == 4) {
- ScalaUDF(func, returnType, e, e.map(_ => true), udfName = Some(name))
+ ScalaUDF(func, returnType, e, e.map(_ => false), udfName = Some(name))
} else {
throw new AnalysisException("Invalid number of arguments for function " + name +
". Expected: 4; Found: " + e.length)
@@ -794,7 +794,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
def register(name: String, f: UDF5[_, _, _, _, _, _], returnType: DataType): Unit = {
val func = f.asInstanceOf[UDF5[Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any)
def builder(e: Seq[Expression]) = if (e.length == 5) {
- ScalaUDF(func, returnType, e, e.map(_ => true), udfName = Some(name))
+ ScalaUDF(func, returnType, e, e.map(_ => false), udfName = Some(name))
} else {
throw new AnalysisException("Invalid number of arguments for function " + name +
". Expected: 5; Found: " + e.length)
@@ -809,7 +809,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
def register(name: String, f: UDF6[_, _, _, _, _, _, _], returnType: DataType): Unit = {
val func = f.asInstanceOf[UDF6[Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
def builder(e: Seq[Expression]) = if (e.length == 6) {
- ScalaUDF(func, returnType, e, e.map(_ => true), udfName = Some(name))
+ ScalaUDF(func, returnType, e, e.map(_ => false), udfName = Some(name))
} else {
throw new AnalysisException("Invalid number of arguments for function " + name +
". Expected: 6; Found: " + e.length)
@@ -824,7 +824,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
def register(name: String, f: UDF7[_, _, _, _, _, _, _, _], returnType: DataType): Unit = {
val func = f.asInstanceOf[UDF7[Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
def builder(e: Seq[Expression]) = if (e.length == 7) {
- ScalaUDF(func, returnType, e, e.map(_ => true), udfName = Some(name))
+ ScalaUDF(func, returnType, e, e.map(_ => false), udfName = Some(name))
} else {
throw new AnalysisException("Invalid number of arguments for function " + name +
". Expected: 7; Found: " + e.length)
@@ -839,7 +839,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
def register(name: String, f: UDF8[_, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
val func = f.asInstanceOf[UDF8[Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
def builder(e: Seq[Expression]) = if (e.length == 8) {
- ScalaUDF(func, returnType, e, e.map(_ => true), udfName = Some(name))
+ ScalaUDF(func, returnType, e, e.map(_ => false), udfName = Some(name))
} else {
throw new AnalysisException("Invalid number of arguments for function " + name +
". Expected: 8; Found: " + e.length)
@@ -854,7 +854,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
def register(name: String, f: UDF9[_, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
val func = f.asInstanceOf[UDF9[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
def builder(e: Seq[Expression]) = if (e.length == 9) {
- ScalaUDF(func, returnType, e, e.map(_ => true), udfName = Some(name))
+ ScalaUDF(func, returnType, e, e.map(_ => false), udfName = Some(name))
} else {
throw new AnalysisException("Invalid number of arguments for function " + name +
". Expected: 9; Found: " + e.length)
@@ -869,7 +869,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
def register(name: String, f: UDF10[_, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
val func = f.asInstanceOf[UDF10[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
def builder(e: Seq[Expression]) = if (e.length == 10) {
- ScalaUDF(func, returnType, e, e.map(_ => true), udfName = Some(name))
+ ScalaUDF(func, returnType, e, e.map(_ => false), udfName = Some(name))
} else {
throw new AnalysisException("Invalid number of arguments for function " + name +
". Expected: 10; Found: " + e.length)
@@ -884,7 +884,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
def register(name: String, f: UDF11[_, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
val func = f.asInstanceOf[UDF11[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
def builder(e: Seq[Expression]) = if (e.length == 11) {
- ScalaUDF(func, returnType, e, e.map(_ => true), udfName = Some(name))
+ ScalaUDF(func, returnType, e, e.map(_ => false), udfName = Some(name))
} else {
throw new AnalysisException("Invalid number of arguments for function " + name +
". Expected: 11; Found: " + e.length)
@@ -899,7 +899,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
def register(name: String, f: UDF12[_, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
val func = f.asInstanceOf[UDF12[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
def builder(e: Seq[Expression]) = if (e.length == 12) {
- ScalaUDF(func, returnType, e, e.map(_ => true), udfName = Some(name))
+ ScalaUDF(func, returnType, e, e.map(_ => false), udfName = Some(name))
} else {
throw new AnalysisException("Invalid number of arguments for function " + name +
". Expected: 12; Found: " + e.length)
@@ -914,7 +914,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
def register(name: String, f: UDF13[_, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
val func = f.asInstanceOf[UDF13[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
def builder(e: Seq[Expression]) = if (e.length == 13) {
- ScalaUDF(func, returnType, e, e.map(_ => true), udfName = Some(name))
+ ScalaUDF(func, returnType, e, e.map(_ => false), udfName = Some(name))
} else {
throw new AnalysisException("Invalid number of arguments for function " + name +
". Expected: 13; Found: " + e.length)
@@ -929,7 +929,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
def register(name: String, f: UDF14[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
val func = f.asInstanceOf[UDF14[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
def builder(e: Seq[Expression]) = if (e.length == 14) {
- ScalaUDF(func, returnType, e, e.map(_ => true), udfName = Some(name))
+ ScalaUDF(func, returnType, e, e.map(_ => false), udfName = Some(name))
} else {
throw new AnalysisException("Invalid number of arguments for function " + name +
". Expected: 14; Found: " + e.length)
@@ -944,7 +944,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
def register(name: String, f: UDF15[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
val func = f.asInstanceOf[UDF15[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
def builder(e: Seq[Expression]) = if (e.length == 15) {
- ScalaUDF(func, returnType, e, e.map(_ => true), udfName = Some(name))
+ ScalaUDF(func, returnType, e, e.map(_ => false), udfName = Some(name))
} else {
throw new AnalysisException("Invalid number of arguments for function " + name +
". Expected: 15; Found: " + e.length)
@@ -959,7 +959,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
def register(name: String, f: UDF16[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
val func = f.asInstanceOf[UDF16[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
def builder(e: Seq[Expression]) = if (e.length == 16) {
- ScalaUDF(func, returnType, e, e.map(_ => true), udfName = Some(name))
+ ScalaUDF(func, returnType, e, e.map(_ => false), udfName = Some(name))
} else {
throw new AnalysisException("Invalid number of arguments for function " + name +
". Expected: 16; Found: " + e.length)
@@ -974,7 +974,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
def register(name: String, f: UDF17[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
val func = f.asInstanceOf[UDF17[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
def builder(e: Seq[Expression]) = if (e.length == 17) {
- ScalaUDF(func, returnType, e, e.map(_ => true), udfName = Some(name))
+ ScalaUDF(func, returnType, e, e.map(_ => false), udfName = Some(name))
} else {
throw new AnalysisException("Invalid number of arguments for function " + name +
". Expected: 17; Found: " + e.length)
@@ -989,7 +989,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
def register(name: String, f: UDF18[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
val func = f.asInstanceOf[UDF18[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
def builder(e: Seq[Expression]) = if (e.length == 18) {
- ScalaUDF(func, returnType, e, e.map(_ => true), udfName = Some(name))
+ ScalaUDF(func, returnType, e, e.map(_ => false), udfName = Some(name))
} else {
throw new AnalysisException("Invalid number of arguments for function " + name +
". Expected: 18; Found: " + e.length)
@@ -1004,7 +1004,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
def register(name: String, f: UDF19[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
val func = f.asInstanceOf[UDF19[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
def builder(e: Seq[Expression]) = if (e.length == 19) {
- ScalaUDF(func, returnType, e, e.map(_ => true), udfName = Some(name))
+ ScalaUDF(func, returnType, e, e.map(_ => false), udfName = Some(name))
} else {
throw new AnalysisException("Invalid number of arguments for function " + name +
". Expected: 19; Found: " + e.length)
@@ -1019,7 +1019,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
def register(name: String, f: UDF20[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
val func = f.asInstanceOf[UDF20[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
def builder(e: Seq[Expression]) = if (e.length == 20) {
- ScalaUDF(func, returnType, e, e.map(_ => true), udfName = Some(name))
+ ScalaUDF(func, returnType, e, e.map(_ => false), udfName = Some(name))
} else {
throw new AnalysisException("Invalid number of arguments for function " + name +
". Expected: 20; Found: " + e.length)
@@ -1034,7 +1034,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
def register(name: String, f: UDF21[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
val func = f.asInstanceOf[UDF21[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
def builder(e: Seq[Expression]) = if (e.length == 21) {
- ScalaUDF(func, returnType, e, e.map(_ => true), udfName = Some(name))
+ ScalaUDF(func, returnType, e, e.map(_ => false), udfName = Some(name))
} else {
throw new AnalysisException("Invalid number of arguments for function " + name +
". Expected: 21; Found: " + e.length)
@@ -1049,7 +1049,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
def register(name: String, f: UDF22[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
val func = f.asInstanceOf[UDF22[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
def builder(e: Seq[Expression]) = if (e.length == 22) {
- ScalaUDF(func, returnType, e, e.map(_ => true), udfName = Some(name))
+ ScalaUDF(func, returnType, e, e.map(_ => false), udfName = Some(name))
} else {
throw new AnalysisException("Invalid number of arguments for function " + name +
". Expected: 22; Found: " + e.length)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala
index b6b0d7a..2595cc6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala
@@ -181,7 +181,7 @@ class DynamicPartitionDataWriter(
ExternalCatalogUtils.getPartitionPathString _,
StringType,
Seq(Literal(c.name), Cast(c, StringType, Option(description.timeZoneId))),
- Seq(true, true))
+ Seq(false, false))
if (i == 0) Seq(partitionName) else Seq(Literal(Path.SEPARATOR), partitionName)
})
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala
index 4d8e1c5..0c956ec 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala
@@ -100,14 +100,16 @@ private[sql] case class SparkUserDefinedFunction(
private[sql] def createScalaUDF(exprs: Seq[Expression]): ScalaUDF = {
// It's possible that some of the inputs don't have a specific type(e.g. `Any`), skip type
- // check and null check for them.
+ // check.
val inputTypes = inputSchemas.map(_.map(_.dataType).getOrElse(AnyDataType))
- val inputsNullSafe = inputSchemas.map(_.map(_.nullable).getOrElse(true))
+ // `ScalaReflection.Schema.nullable` is false iff the type is primitive. Also `Any` is not
+ // primitive.
+ val inputsPrimitive = inputSchemas.map(_.map(!_.nullable).getOrElse(false))
ScalaUDF(
f,
dataType,
exprs,
- inputsNullSafe,
+ inputsPrimitive,
inputTypes,
udfName = name,
nullable = nullable,
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
index 5ac2093..e515800 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
@@ -400,17 +400,23 @@ class UDFSuite extends QueryTest with SharedSQLContext {
}
test("SPARK-25044 Verify null input handling for primitive types - with udf()") {
- val udf1 = udf((x: Long, y: Any) => x * 2 + (if (y == null) 1 else 0))
- val df = spark.range(0, 3).toDF("a")
- .withColumn("b", udf1($"a", lit(null)))
- .withColumn("c", udf1(lit(null), $"a"))
-
- checkAnswer(
- df,
- Seq(
- Row(0, 1, null),
- Row(1, 3, null),
- Row(2, 5, null)))
+ val input = Seq(
+ (null, Integer.valueOf(1), "x"),
+ ("M", null, "y"),
+ ("N", Integer.valueOf(3), null)).toDF("a", "b", "c")
+
+ val udf1 = udf((a: String, b: Int, c: Any) => a + b + c)
+ val df = input.select(udf1('a, 'b, 'c))
+ checkAnswer(df, Seq(Row("null1x"), Row(null), Row("N3null")))
+
+ // test Java UDF. Java UDF can't have primitive inputs, as it's generic typed.
+ val udf2 = udf(new UDF3[String, Integer, Object, String] {
+ override def call(t1: String, t2: Integer, t3: Object): String = {
+ t1 + t2 + t3
+ }
+ }, StringType)
+ val df2 = input.select(udf2('a, 'b, 'c))
+ checkAnswer(df2, Seq(Row("null1x"), Row("Mnully"), Row("N3null")))
}
test("SPARK-25044 Verify null input handling for primitive types - with udf.register") {
@@ -420,6 +426,15 @@ class UDFSuite extends QueryTest with SharedSQLContext {
spark.udf.register("f", (a: String, b: Int, c: Any) => a + b + c)
val df = spark.sql("SELECT f(a, b, c) FROM t")
checkAnswer(df, Seq(Row("null1x"), Row(null), Row("N3null")))
+
+ // test Java UDF. Java UDF can't have primitive inputs, as it's generic typed.
+ spark.udf.register("f2", new UDF3[String, Integer, Object, String] {
+ override def call(t1: String, t2: Integer, t3: Object): String = {
+ t1 + t2 + t3
+ }
+ }, StringType)
+ val df2 = spark.sql("SELECT f2(a, b, c) FROM t")
+ checkAnswer(df2, Seq(Row("null1x"), Row("Mnully"), Row("N3null")))
}
}
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org