You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by we...@apache.org on 2017/07/11 03:20:04 UTC
spark git commit: [SPARK-21350][SQL] Fix the error message when the
number of arguments is wrong when invoking a UDF
Repository: spark
Updated Branches:
refs/heads/master a2bec6c92 -> 1471ee7af
[SPARK-21350][SQL] Fix the error message when the number of arguments is wrong when invoking a UDF
### What changes were proposed in this pull request?
Users get a very confusing error when they invoke a UDF with the wrong number of arguments.
```Scala
val df = spark.emptyDataFrame
spark.udf.register("foo", (_: String).length)
df.selectExpr("foo(2, 3, 4)")
```
```
org.apache.spark.sql.UDFSuite$$anonfun$9$$anonfun$apply$mcV$sp$12 cannot be cast to scala.Function3
java.lang.ClassCastException: org.apache.spark.sql.UDFSuite$$anonfun$9$$anonfun$apply$mcV$sp$12 cannot be cast to scala.Function3
at org.apache.spark.sql.catalyst.expressions.ScalaUDF.<init>(ScalaUDF.scala:109)
```
This PR makes each UDF builder check the number of arguments before constructing the `ScalaUDF`. On a mismatch it throws an `AnalysisException` whose message is consistent with what we do for built-in functions. After the fix, the error message becomes:
```
Invalid number of arguments for function foo. Expected: 1; Found: 3; line 1 pos 0
org.apache.spark.sql.AnalysisException: Invalid number of arguments for function foo. Expected: 1; Found: 3; line 1 pos 0
at org.apache.spark.sql.catalyst.analysis.SimpleFunctionRegistry.lookupFunction(FunctionRegistry.scala:119)
```
### How was this patch tested?
Added test cases to `UDFSuite` (Scala) and `JavaUDFSuite` (Java).
Author: gatorsmile <ga...@gmail.com>
Closes #18574 from gatorsmile/statsCheck.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/1471ee7a
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/1471ee7a
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/1471ee7a
Branch: refs/heads/master
Commit: 1471ee7af5a9952b60cf8c56d60cb6a7ec46cc69
Parents: a2bec6c
Author: gatorsmile <ga...@gmail.com>
Authored: Tue Jul 11 11:19:59 2017 +0800
Committer: Wenchen Fan <we...@databricks.com>
Committed: Tue Jul 11 11:19:59 2017 +0800
----------------------------------------------------------------------
.../org/apache/spark/sql/UDFRegistration.scala | 412 ++++++++++++++-----
.../test/org/apache/spark/sql/JavaUDFSuite.java | 8 +
.../scala/org/apache/spark/sql/UDFSuite.scala | 13 +-
3 files changed, 331 insertions(+), 102 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/1471ee7a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
index 8bdc022..c4d0adb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
@@ -111,7 +111,12 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
def register[$typeTags](name: String, func: Function$x[$types]): UserDefinedFunction = {
val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
val inputTypes = Try($inputTypes).toOption
- def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable)
+ def builder(e: Seq[Expression]) = if (e.length == $x) {
+ ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable)
+ } else {
+ throw new AnalysisException("Invalid number of arguments for function " + name +
+ ". Expected: $x; Found: " + e.length)
+ }
functionRegistry.createOrReplaceTempFunction(name, builder)
UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable)
}""")
@@ -123,16 +128,20 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
val anyCast = s".asInstanceOf[UDF$i[$anyTypeArgs, Any]]"
val anyParams = (1 to i).map(_ => "_: Any").mkString(", ")
println(s"""
- |/**
- | * Register a user-defined function with ${i} arguments.
- | * @since 1.3.0
- | */
- |def register(name: String, f: UDF$i[$extTypeArgs, _], returnType: DataType): Unit = {
- | val func = f$anyCast.call($anyParams)
- | functionRegistry.createOrReplaceTempFunction(
- | name,
- | (e: Seq[Expression]) => ScalaUDF(func, returnType, e))
- |}""".stripMargin)
+ |/**
+ | * Register a user-defined function with ${i} arguments.
+ | * @since 1.3.0
+ | */
+ |def register(name: String, f: UDF$i[$extTypeArgs, _], returnType: DataType): Unit = {
+ | val func = f$anyCast.call($anyParams)
+ |def builder(e: Seq[Expression]) = if (e.length == $i) {
+ | ScalaUDF(func, returnType, e)
+ |} else {
+ | throw new AnalysisException("Invalid number of arguments for function " + name +
+ | ". Expected: $i; Found: " + e.length)
+ |}
+ |functionRegistry.createOrReplaceTempFunction(name, builder)
+ |}""".stripMargin)
}
*/
@@ -144,7 +153,12 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
def register[RT: TypeTag](name: String, func: Function0[RT]): UserDefinedFunction = {
val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
val inputTypes = Try(Nil).toOption
- def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable)
+ def builder(e: Seq[Expression]) = if (e.length == 0) {
+ ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable)
+ } else {
+ throw new AnalysisException("Invalid number of arguments for function " + name +
+ ". Expected: 0; Found: " + e.length)
+ }
functionRegistry.createOrReplaceTempFunction(name, builder)
UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable)
}
@@ -157,7 +171,12 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
def register[RT: TypeTag, A1: TypeTag](name: String, func: Function1[A1, RT]): UserDefinedFunction = {
val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: Nil).toOption
- def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable)
+ def builder(e: Seq[Expression]) = if (e.length == 1) {
+ ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable)
+ } else {
+ throw new AnalysisException("Invalid number of arguments for function " + name +
+ ". Expected: 1; Found: " + e.length)
+ }
functionRegistry.createOrReplaceTempFunction(name, builder)
UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable)
}
@@ -170,7 +189,12 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
def register[RT: TypeTag, A1: TypeTag, A2: TypeTag](name: String, func: Function2[A1, A2, RT]): UserDefinedFunction = {
val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: Nil).toOption
- def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable)
+ def builder(e: Seq[Expression]) = if (e.length == 2) {
+ ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable)
+ } else {
+ throw new AnalysisException("Invalid number of arguments for function " + name +
+ ". Expected: 2; Found: " + e.length)
+ }
functionRegistry.createOrReplaceTempFunction(name, builder)
UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable)
}
@@ -183,7 +207,12 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag](name: String, func: Function3[A1, A2, A3, RT]): UserDefinedFunction = {
val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: Nil).toOption
- def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable)
+ def builder(e: Seq[Expression]) = if (e.length == 3) {
+ ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable)
+ } else {
+ throw new AnalysisException("Invalid number of arguments for function " + name +
+ ". Expected: 3; Found: " + e.length)
+ }
functionRegistry.createOrReplaceTempFunction(name, builder)
UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable)
}
@@ -196,7 +225,12 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag](name: String, func: Function4[A1, A2, A3, A4, RT]): UserDefinedFunction = {
val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: Nil).toOption
- def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable)
+ def builder(e: Seq[Expression]) = if (e.length == 4) {
+ ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable)
+ } else {
+ throw new AnalysisException("Invalid number of arguments for function " + name +
+ ". Expected: 4; Found: " + e.length)
+ }
functionRegistry.createOrReplaceTempFunction(name, builder)
UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable)
}
@@ -209,7 +243,12 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag](name: String, func: Function5[A1, A2, A3, A4, A5, RT]): UserDefinedFunction = {
val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: Nil).toOption
- def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable)
+ def builder(e: Seq[Expression]) = if (e.length == 5) {
+ ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable)
+ } else {
+ throw new AnalysisException("Invalid number of arguments for function " + name +
+ ". Expected: 5; Found: " + e.length)
+ }
functionRegistry.createOrReplaceTempFunction(name, builder)
UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable)
}
@@ -222,7 +261,12 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag](name: String, func: Function6[A1, A2, A3, A4, A5, A6, RT]): UserDefinedFunction = {
val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: Nil).toOption
- def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable)
+ def builder(e: Seq[Expression]) = if (e.length == 6) {
+ ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable)
+ } else {
+ throw new AnalysisException("Invalid number of arguments for function " + name +
+ ". Expected: 6; Found: " + e.length)
+ }
functionRegistry.createOrReplaceTempFunction(name, builder)
UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable)
}
@@ -235,7 +279,12 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag](name: String, func: Function7[A1, A2, A3, A4, A5, A6, A7, RT]): UserDefinedFunction = {
val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: Nil).toOption
- def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable)
+ def builder(e: Seq[Expression]) = if (e.length == 7) {
+ ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable)
+ } else {
+ throw new AnalysisException("Invalid number of arguments for function " + name +
+ ". Expected: 7; Found: " + e.length)
+ }
functionRegistry.createOrReplaceTempFunction(name, builder)
UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable)
}
@@ -248,7 +297,12 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag](name: String, func: Function8[A1, A2, A3, A4, A5, A6, A7, A8, RT]): UserDefinedFunction = {
val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: Nil).toOption
- def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable)
+ def builder(e: Seq[Expression]) = if (e.length == 8) {
+ ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable)
+ } else {
+ throw new AnalysisException("Invalid number of arguments for function " + name +
+ ". Expected: 8; Found: " + e.length)
+ }
functionRegistry.createOrReplaceTempFunction(name, builder)
UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable)
}
@@ -261,7 +315,12 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag](name: String, func: Function9[A1, A2, A3, A4, A5, A6, A7, A8, A9, RT]): UserDefinedFunction = {
val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: Nil).toOption
- def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable)
+ def builder(e: Seq[Expression]) = if (e.length == 9) {
+ ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable)
+ } else {
+ throw new AnalysisException("Invalid number of arguments for function " + name +
+ ". Expected: 9; Found: " + e.length)
+ }
functionRegistry.createOrReplaceTempFunction(name, builder)
UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable)
}
@@ -274,7 +333,12 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag](name: String, func: Function10[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, RT]): UserDefinedFunction = {
val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: Nil).toOption
- def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable)
+ def builder(e: Seq[Expression]) = if (e.length == 10) {
+ ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable)
+ } else {
+ throw new AnalysisException("Invalid number of arguments for function " + name +
+ ". Expected: 10; Found: " + e.length)
+ }
functionRegistry.createOrReplaceTempFunction(name, builder)
UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable)
}
@@ -287,7 +351,12 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag](name: String, func: Function11[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, RT]): UserDefinedFunction = {
val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: Nil).toOption
- def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable)
+ def builder(e: Seq[Expression]) = if (e.length == 11) {
+ ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable)
+ } else {
+ throw new AnalysisException("Invalid number of arguments for function " + name +
+ ". Expected: 11; Found: " + e.length)
+ }
functionRegistry.createOrReplaceTempFunction(name, builder)
UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable)
}
@@ -300,7 +369,12 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag](name: String, func: Function12[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, RT]): UserDefinedFunction = {
val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: Nil).toOption
- def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable)
+ def builder(e: Seq[Expression]) = if (e.length == 12) {
+ ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable)
+ } else {
+ throw new AnalysisException("Invalid number of arguments for function " + name +
+ ". Expected: 12; Found: " + e.length)
+ }
functionRegistry.createOrReplaceTempFunction(name, builder)
UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable)
}
@@ -313,7 +387,12 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag](name: String, func: Function13[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, RT]): UserDefinedFunction = {
val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: Nil).toOption
- def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable)
+ def builder(e: Seq[Expression]) = if (e.length == 13) {
+ ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable)
+ } else {
+ throw new AnalysisException("Invalid number of arguments for function " + name +
+ ". Expected: 13; Found: " + e.length)
+ }
functionRegistry.createOrReplaceTempFunction(name, builder)
UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable)
}
@@ -326,7 +405,12 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag](name: String, func: Function14[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, RT]): UserDefinedFunction = {
val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: Nil).toOption
- def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable)
+ def builder(e: Seq[Expression]) = if (e.length == 14) {
+ ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable)
+ } else {
+ throw new AnalysisException("Invalid number of arguments for function " + name +
+ ". Expected: 14; Found: " + e.length)
+ }
functionRegistry.createOrReplaceTempFunction(name, builder)
UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable)
}
@@ -339,7 +423,12 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag](name: String, func: Function15[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, RT]): UserDefinedFunction = {
val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: Nil).toOption
- def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable)
+ def builder(e: Seq[Expression]) = if (e.length == 15) {
+ ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable)
+ } else {
+ throw new AnalysisException("Invalid number of arguments for function " + name +
+ ". Expected: 15; Found: " + e.length)
+ }
functionRegistry.createOrReplaceTempFunction(name, builder)
UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable)
}
@@ -352,7 +441,12 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag](name: String, func: Function16[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, RT]): UserDefinedFunction = {
val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: Nil).toOption
- def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable)
+ def builder(e: Seq[Expression]) = if (e.length == 16) {
+ ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable)
+ } else {
+ throw new AnalysisException("Invalid number of arguments for function " + name +
+ ". Expected: 16; Found: " + e.length)
+ }
functionRegistry.createOrReplaceTempFunction(name, builder)
UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable)
}
@@ -365,7 +459,12 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag](name: String, func: Function17[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, RT]): UserDefinedFunction = {
val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: Nil).toOption
- def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable)
+ def builder(e: Seq[Expression]) = if (e.length == 17) {
+ ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable)
+ } else {
+ throw new AnalysisException("Invalid number of arguments for function " + name +
+ ". Expected: 17; Found: " + e.length)
+ }
functionRegistry.createOrReplaceTempFunction(name, builder)
UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable)
}
@@ -378,7 +477,12 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag](name: String, func: Function18[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, RT]): UserDefinedFunction = {
val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: ScalaReflection.schemaFor[A18].dataType :: Nil).toOption
- def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable)
+ def builder(e: Seq[Expression]) = if (e.length == 18) {
+ ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable)
+ } else {
+ throw new AnalysisException("Invalid number of arguments for function " + name +
+ ". Expected: 18; Found: " + e.length)
+ }
functionRegistry.createOrReplaceTempFunction(name, builder)
UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable)
}
@@ -391,7 +495,12 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag](name: String, func: Function19[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, RT]): UserDefinedFunction = {
val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: ScalaReflection.schemaFor[A18].dataType :: ScalaReflection.schemaFor[A19].dataType :: Nil).toOption
- def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable)
+ def builder(e: Seq[Expression]) = if (e.length == 19) {
+ ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable)
+ } else {
+ throw new AnalysisException("Invalid number of arguments for function " + name +
+ ". Expected: 19; Found: " + e.length)
+ }
functionRegistry.createOrReplaceTempFunction(name, builder)
UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable)
}
@@ -404,7 +513,12 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag, A20: TypeTag](name: String, func: Function20[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, RT]): UserDefinedFunction = {
val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: ScalaReflection.schemaFor[A18].dataType :: ScalaReflection.schemaFor[A19].dataType :: ScalaReflection.schemaFor[A20].dataType :: Nil).toOption
- def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable)
+ def builder(e: Seq[Expression]) = if (e.length == 20) {
+ ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable)
+ } else {
+ throw new AnalysisException("Invalid number of arguments for function " + name +
+ ". Expected: 20; Found: " + e.length)
+ }
functionRegistry.createOrReplaceTempFunction(name, builder)
UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable)
}
@@ -417,7 +531,12 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag, A20: TypeTag, A21: TypeTag](name: String, func: Function21[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, A21, RT]): UserDefinedFunction = {
val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: ScalaReflection.schemaFor[A18].dataType :: ScalaReflection.schemaFor[A19].dataType :: ScalaReflection.schemaFor[A20].dataType :: ScalaReflection.schemaFor[A21].dataType :: Nil).toOption
- def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable)
+ def builder(e: Seq[Expression]) = if (e.length == 21) {
+ ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable)
+ } else {
+ throw new AnalysisException("Invalid number of arguments for function " + name +
+ ". Expected: 21; Found: " + e.length)
+ }
functionRegistry.createOrReplaceTempFunction(name, builder)
UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable)
}
@@ -430,7 +549,12 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag, A20: TypeTag, A21: TypeTag, A22: TypeTag](name: String, func: Function22[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, A21, A22, RT]): UserDefinedFunction = {
val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: ScalaReflection.schemaFor[A18].dataType :: ScalaReflection.schemaFor[A19].dataType :: ScalaReflection.schemaFor[A20].dataType :: ScalaReflection.schemaFor[A21].dataType :: ScalaReflection.schemaFor[A22].dataType :: Nil).toOption
- def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable)
+ def builder(e: Seq[Expression]) = if (e.length == 22) {
+ ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable)
+ } else {
+ throw new AnalysisException("Invalid number of arguments for function " + name +
+ ". Expected: 22; Found: " + e.length)
+ }
functionRegistry.createOrReplaceTempFunction(name, builder)
UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable)
}
@@ -531,9 +655,13 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
*/
def register(name: String, f: UDF1[_, _], returnType: DataType): Unit = {
val func = f.asInstanceOf[UDF1[Any, Any]].call(_: Any)
- functionRegistry.createOrReplaceTempFunction(
- name,
- (e: Seq[Expression]) => ScalaUDF(func, returnType, e))
+ def builder(e: Seq[Expression]) = if (e.length == 1) {
+ ScalaUDF(func, returnType, e)
+ } else {
+ throw new AnalysisException("Invalid number of arguments for function " + name +
+ ". Expected: 1; Found: " + e.length)
+ }
+ functionRegistry.createOrReplaceTempFunction(name, builder)
}
/**
@@ -542,9 +670,13 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
*/
def register(name: String, f: UDF2[_, _, _], returnType: DataType): Unit = {
val func = f.asInstanceOf[UDF2[Any, Any, Any]].call(_: Any, _: Any)
- functionRegistry.createOrReplaceTempFunction(
- name,
- (e: Seq[Expression]) => ScalaUDF(func, returnType, e))
+ def builder(e: Seq[Expression]) = if (e.length == 2) {
+ ScalaUDF(func, returnType, e)
+ } else {
+ throw new AnalysisException("Invalid number of arguments for function " + name +
+ ". Expected: 2; Found: " + e.length)
+ }
+ functionRegistry.createOrReplaceTempFunction(name, builder)
}
/**
@@ -553,9 +685,13 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
*/
def register(name: String, f: UDF3[_, _, _, _], returnType: DataType): Unit = {
val func = f.asInstanceOf[UDF3[Any, Any, Any, Any]].call(_: Any, _: Any, _: Any)
- functionRegistry.createOrReplaceTempFunction(
- name,
- (e: Seq[Expression]) => ScalaUDF(func, returnType, e))
+ def builder(e: Seq[Expression]) = if (e.length == 3) {
+ ScalaUDF(func, returnType, e)
+ } else {
+ throw new AnalysisException("Invalid number of arguments for function " + name +
+ ". Expected: 3; Found: " + e.length)
+ }
+ functionRegistry.createOrReplaceTempFunction(name, builder)
}
/**
@@ -564,9 +700,13 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
*/
def register(name: String, f: UDF4[_, _, _, _, _], returnType: DataType): Unit = {
val func = f.asInstanceOf[UDF4[Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any)
- functionRegistry.createOrReplaceTempFunction(
- name,
- (e: Seq[Expression]) => ScalaUDF(func, returnType, e))
+ def builder(e: Seq[Expression]) = if (e.length == 4) {
+ ScalaUDF(func, returnType, e)
+ } else {
+ throw new AnalysisException("Invalid number of arguments for function " + name +
+ ". Expected: 4; Found: " + e.length)
+ }
+ functionRegistry.createOrReplaceTempFunction(name, builder)
}
/**
@@ -575,9 +715,13 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
*/
def register(name: String, f: UDF5[_, _, _, _, _, _], returnType: DataType): Unit = {
val func = f.asInstanceOf[UDF5[Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any)
- functionRegistry.createOrReplaceTempFunction(
- name,
- (e: Seq[Expression]) => ScalaUDF(func, returnType, e))
+ def builder(e: Seq[Expression]) = if (e.length == 5) {
+ ScalaUDF(func, returnType, e)
+ } else {
+ throw new AnalysisException("Invalid number of arguments for function " + name +
+ ". Expected: 5; Found: " + e.length)
+ }
+ functionRegistry.createOrReplaceTempFunction(name, builder)
}
/**
@@ -586,9 +730,13 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
*/
def register(name: String, f: UDF6[_, _, _, _, _, _, _], returnType: DataType): Unit = {
val func = f.asInstanceOf[UDF6[Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
- functionRegistry.createOrReplaceTempFunction(
- name,
- (e: Seq[Expression]) => ScalaUDF(func, returnType, e))
+ def builder(e: Seq[Expression]) = if (e.length == 6) {
+ ScalaUDF(func, returnType, e)
+ } else {
+ throw new AnalysisException("Invalid number of arguments for function " + name +
+ ". Expected: 6; Found: " + e.length)
+ }
+ functionRegistry.createOrReplaceTempFunction(name, builder)
}
/**
@@ -597,9 +745,13 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
*/
def register(name: String, f: UDF7[_, _, _, _, _, _, _, _], returnType: DataType): Unit = {
val func = f.asInstanceOf[UDF7[Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
- functionRegistry.createOrReplaceTempFunction(
- name,
- (e: Seq[Expression]) => ScalaUDF(func, returnType, e))
+ def builder(e: Seq[Expression]) = if (e.length == 7) {
+ ScalaUDF(func, returnType, e)
+ } else {
+ throw new AnalysisException("Invalid number of arguments for function " + name +
+ ". Expected: 7; Found: " + e.length)
+ }
+ functionRegistry.createOrReplaceTempFunction(name, builder)
}
/**
@@ -608,9 +760,13 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
*/
def register(name: String, f: UDF8[_, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
val func = f.asInstanceOf[UDF8[Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
- functionRegistry.createOrReplaceTempFunction(
- name,
- (e: Seq[Expression]) => ScalaUDF(func, returnType, e))
+ def builder(e: Seq[Expression]) = if (e.length == 8) {
+ ScalaUDF(func, returnType, e)
+ } else {
+ throw new AnalysisException("Invalid number of arguments for function " + name +
+ ". Expected: 8; Found: " + e.length)
+ }
+ functionRegistry.createOrReplaceTempFunction(name, builder)
}
/**
@@ -619,9 +775,13 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
*/
def register(name: String, f: UDF9[_, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
val func = f.asInstanceOf[UDF9[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
- functionRegistry.createOrReplaceTempFunction(
- name,
- (e: Seq[Expression]) => ScalaUDF(func, returnType, e))
+ def builder(e: Seq[Expression]) = if (e.length == 9) {
+ ScalaUDF(func, returnType, e)
+ } else {
+ throw new AnalysisException("Invalid number of arguments for function " + name +
+ ". Expected: 9; Found: " + e.length)
+ }
+ functionRegistry.createOrReplaceTempFunction(name, builder)
}
/**
@@ -630,9 +790,13 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
*/
def register(name: String, f: UDF10[_, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
val func = f.asInstanceOf[UDF10[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
- functionRegistry.createOrReplaceTempFunction(
- name,
- (e: Seq[Expression]) => ScalaUDF(func, returnType, e))
+ def builder(e: Seq[Expression]) = if (e.length == 10) {
+ ScalaUDF(func, returnType, e)
+ } else {
+ throw new AnalysisException("Invalid number of arguments for function " + name +
+ ". Expected: 10; Found: " + e.length)
+ }
+ functionRegistry.createOrReplaceTempFunction(name, builder)
}
/**
@@ -641,9 +805,13 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
*/
def register(name: String, f: UDF11[_, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
val func = f.asInstanceOf[UDF11[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
- functionRegistry.createOrReplaceTempFunction(
- name,
- (e: Seq[Expression]) => ScalaUDF(func, returnType, e))
+ def builder(e: Seq[Expression]) = if (e.length == 11) {
+ ScalaUDF(func, returnType, e)
+ } else {
+ throw new AnalysisException("Invalid number of arguments for function " + name +
+ ". Expected: 11; Found: " + e.length)
+ }
+ functionRegistry.createOrReplaceTempFunction(name, builder)
}
/**
@@ -652,9 +820,13 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
*/
def register(name: String, f: UDF12[_, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
val func = f.asInstanceOf[UDF12[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
- functionRegistry.createOrReplaceTempFunction(
- name,
- (e: Seq[Expression]) => ScalaUDF(func, returnType, e))
+ def builder(e: Seq[Expression]) = if (e.length == 12) {
+ ScalaUDF(func, returnType, e)
+ } else {
+ throw new AnalysisException("Invalid number of arguments for function " + name +
+ ". Expected: 12; Found: " + e.length)
+ }
+ functionRegistry.createOrReplaceTempFunction(name, builder)
}
/**
@@ -663,9 +835,13 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
*/
def register(name: String, f: UDF13[_, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
val func = f.asInstanceOf[UDF13[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
- functionRegistry.createOrReplaceTempFunction(
- name,
- (e: Seq[Expression]) => ScalaUDF(func, returnType, e))
+ def builder(e: Seq[Expression]) = if (e.length == 13) {
+ ScalaUDF(func, returnType, e)
+ } else {
+ throw new AnalysisException("Invalid number of arguments for function " + name +
+ ". Expected: 13; Found: " + e.length)
+ }
+ functionRegistry.createOrReplaceTempFunction(name, builder)
}
/**
@@ -674,9 +850,13 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
*/
def register(name: String, f: UDF14[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
val func = f.asInstanceOf[UDF14[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
- functionRegistry.createOrReplaceTempFunction(
- name,
- (e: Seq[Expression]) => ScalaUDF(func, returnType, e))
+ def builder(e: Seq[Expression]) = if (e.length == 14) {
+ ScalaUDF(func, returnType, e)
+ } else {
+ throw new AnalysisException("Invalid number of arguments for function " + name +
+ ". Expected: 14; Found: " + e.length)
+ }
+ functionRegistry.createOrReplaceTempFunction(name, builder)
}
/**
@@ -685,9 +865,13 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
*/
def register(name: String, f: UDF15[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
val func = f.asInstanceOf[UDF15[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
- functionRegistry.createOrReplaceTempFunction(
- name,
- (e: Seq[Expression]) => ScalaUDF(func, returnType, e))
+ def builder(e: Seq[Expression]) = if (e.length == 15) {
+ ScalaUDF(func, returnType, e)
+ } else {
+ throw new AnalysisException("Invalid number of arguments for function " + name +
+ ". Expected: 15; Found: " + e.length)
+ }
+ functionRegistry.createOrReplaceTempFunction(name, builder)
}
/**
@@ -696,9 +880,13 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
*/
def register(name: String, f: UDF16[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
val func = f.asInstanceOf[UDF16[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
- functionRegistry.createOrReplaceTempFunction(
- name,
- (e: Seq[Expression]) => ScalaUDF(func, returnType, e))
+ def builder(e: Seq[Expression]) = if (e.length == 16) {
+ ScalaUDF(func, returnType, e)
+ } else {
+ throw new AnalysisException("Invalid number of arguments for function " + name +
+ ". Expected: 16; Found: " + e.length)
+ }
+ functionRegistry.createOrReplaceTempFunction(name, builder)
}
/**
@@ -707,9 +895,13 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
*/
def register(name: String, f: UDF17[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
val func = f.asInstanceOf[UDF17[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
- functionRegistry.createOrReplaceTempFunction(
- name,
- (e: Seq[Expression]) => ScalaUDF(func, returnType, e))
+ def builder(e: Seq[Expression]) = if (e.length == 17) {
+ ScalaUDF(func, returnType, e)
+ } else {
+ throw new AnalysisException("Invalid number of arguments for function " + name +
+ ". Expected: 17; Found: " + e.length)
+ }
+ functionRegistry.createOrReplaceTempFunction(name, builder)
}
/**
@@ -718,9 +910,13 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
*/
def register(name: String, f: UDF18[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
val func = f.asInstanceOf[UDF18[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
- functionRegistry.createOrReplaceTempFunction(
- name,
- (e: Seq[Expression]) => ScalaUDF(func, returnType, e))
+ def builder(e: Seq[Expression]) = if (e.length == 18) {
+ ScalaUDF(func, returnType, e)
+ } else {
+ throw new AnalysisException("Invalid number of arguments for function " + name +
+ ". Expected: 18; Found: " + e.length)
+ }
+ functionRegistry.createOrReplaceTempFunction(name, builder)
}
/**
@@ -729,9 +925,13 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
*/
def register(name: String, f: UDF19[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
val func = f.asInstanceOf[UDF19[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
- functionRegistry.createOrReplaceTempFunction(
- name,
- (e: Seq[Expression]) => ScalaUDF(func, returnType, e))
+ def builder(e: Seq[Expression]) = if (e.length == 19) {
+ ScalaUDF(func, returnType, e)
+ } else {
+ throw new AnalysisException("Invalid number of arguments for function " + name +
+ ". Expected: 19; Found: " + e.length)
+ }
+ functionRegistry.createOrReplaceTempFunction(name, builder)
}
/**
@@ -740,9 +940,13 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
*/
def register(name: String, f: UDF20[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
val func = f.asInstanceOf[UDF20[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
- functionRegistry.createOrReplaceTempFunction(
- name,
- (e: Seq[Expression]) => ScalaUDF(func, returnType, e))
+ def builder(e: Seq[Expression]) = if (e.length == 20) {
+ ScalaUDF(func, returnType, e)
+ } else {
+ throw new AnalysisException("Invalid number of arguments for function " + name +
+ ". Expected: 20; Found: " + e.length)
+ }
+ functionRegistry.createOrReplaceTempFunction(name, builder)
}
/**
@@ -751,9 +955,13 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
*/
def register(name: String, f: UDF21[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
val func = f.asInstanceOf[UDF21[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
- functionRegistry.createOrReplaceTempFunction(
- name,
- (e: Seq[Expression]) => ScalaUDF(func, returnType, e))
+ def builder(e: Seq[Expression]) = if (e.length == 21) {
+ ScalaUDF(func, returnType, e)
+ } else {
+ throw new AnalysisException("Invalid number of arguments for function " + name +
+ ". Expected: 21; Found: " + e.length)
+ }
+ functionRegistry.createOrReplaceTempFunction(name, builder)
}
/**
@@ -762,9 +970,13 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
*/
def register(name: String, f: UDF22[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
val func = f.asInstanceOf[UDF22[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
- functionRegistry.createOrReplaceTempFunction(
- name,
- (e: Seq[Expression]) => ScalaUDF(func, returnType, e))
+ def builder(e: Seq[Expression]) = if (e.length == 22) {
+ ScalaUDF(func, returnType, e)
+ } else {
+ throw new AnalysisException("Invalid number of arguments for function " + name +
+ ". Expected: 22; Found: " + e.length)
+ }
+ functionRegistry.createOrReplaceTempFunction(name, builder)
}
// scalastyle:on line.size.limit
http://git-wip-us.apache.org/repos/asf/spark/blob/1471ee7a/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java
----------------------------------------------------------------------
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java
index 250fa67..4fb2988 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java
@@ -25,6 +25,7 @@ import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
+import org.apache.spark.sql.AnalysisException;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.api.java.UDF2;
@@ -105,4 +106,11 @@ public class JavaUDFSuite implements Serializable {
}
Assert.assertEquals(55, sum);
}
+
+ @SuppressWarnings("unchecked")
+ @Test(expected = AnalysisException.class)
+ public void udf5Test() {
+ spark.udf().register("inc", (Long i) -> i + 1, DataTypes.LongType);
+ List<Row> results = spark.sql("SELECT inc(1, 5)").collectAsList();
+ }
}
http://git-wip-us.apache.org/repos/asf/spark/blob/1471ee7a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
index b4f744b..335b882 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
@@ -71,12 +71,21 @@ class UDFSuite extends QueryTest with SharedSQLContext {
}
}
- test("error reporting for incorrect number of arguments") {
+ test("error reporting for incorrect number of arguments - builtin function") {
val df = spark.emptyDataFrame
val e = intercept[AnalysisException] {
df.selectExpr("substr('abcd', 2, 3, 4)")
}
- assert(e.getMessage.contains("arguments"))
+ assert(e.getMessage.contains("Invalid number of arguments for function substr"))
+ }
+
+ test("error reporting for incorrect number of arguments - udf") {
+ val df = spark.emptyDataFrame
+ val e = intercept[AnalysisException] {
+ spark.udf.register("foo", (_: String).length)
+ df.selectExpr("foo(2, 3, 4)")
+ }
+ assert(e.getMessage.contains("Invalid number of arguments for function foo"))
}
test("error reporting for undefined functions") {
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org