Posted to commits@spark.apache.org by ma...@apache.org on 2016/10/14 22:50:39 UTC
spark git commit: [SPARK-11775][PYSPARK][SQL] Allow PySpark to register Java UDF
Repository: spark
Updated Branches:
refs/heads/master 5aeb7384c -> f00df40cf
[SPARK-11775][PYSPARK][SQL] Allow PySpark to register Java UDF
Currently PySpark can only call built-in Java UDFs, but cannot call custom Java UDFs. It would be better to allow that. Two benefits:
* Leverage the power of rich third-party Java libraries.
* Improve performance: if we use a Python UDF, Python daemons are started on the workers, which affects performance.
Author: Jeff Zhang <zj...@apache.org>
Closes #9766 from zjffdu/SPARK-11775.
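For illustration, a minimal usage sketch of the new API from PySpark, mirroring the doctest added in this commit (the class is the test UDF added below; the jar containing it must be on the classpath):

    from pyspark.sql.types import IntegerType

    # Register the Java class under a SQL function name, with an explicit return type.
    sqlContext.registerJavaFunction(
        "javaStringLength", "test.org.apache.spark.sql.JavaStringLength", IntegerType())
    sqlContext.sql("SELECT javaStringLength('test')").collect()  # [Row(UDF(test)=4)]

    # Omit the return type and Spark infers it from the UDF's type parameters.
    sqlContext.registerJavaFunction(
        "javaStringLength2", "test.org.apache.spark.sql.JavaStringLength")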
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/f00df40c
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/f00df40c
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/f00df40c
Branch: refs/heads/master
Commit: f00df40cfefef0f3fc73f16ada1006e4dcfa5a39
Parents: 5aeb738
Author: Jeff Zhang <zj...@apache.org>
Authored: Fri Oct 14 15:50:35 2016 -0700
Committer: Michael Armbrust <mi...@databricks.com>
Committed: Fri Oct 14 15:50:35 2016 -0700
----------------------------------------------------------------------
python/pyspark/sql/context.py | 28 +++++++-
.../spark/sql/catalyst/JavaTypeInference.scala | 2 +-
.../org/apache/spark/sql/UDFRegistration.scala | 75 +++++++++++++++++++-
.../org/apache/spark/sql/JavaStringLength.java | 30 ++++++++
.../test/org/apache/spark/sql/JavaUDFSuite.java | 21 ++++++
5 files changed, 152 insertions(+), 4 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/f00df40c/python/pyspark/sql/context.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py
index 8264dcf..de4c335 100644
--- a/python/pyspark/sql/context.py
+++ b/python/pyspark/sql/context.py
@@ -28,7 +28,7 @@ from pyspark.sql.session import _monkey_patch_RDD, SparkSession
from pyspark.sql.dataframe import DataFrame
from pyspark.sql.readwriter import DataFrameReader
from pyspark.sql.streaming import DataStreamReader
-from pyspark.sql.types import Row, StringType
+from pyspark.sql.types import IntegerType, Row, StringType
from pyspark.sql.utils import install_exception_handler
__all__ = ["SQLContext", "HiveContext", "UDFRegistration"]
@@ -202,6 +202,32 @@ class SQLContext(object):
"""
self.sparkSession.catalog.registerFunction(name, f, returnType)
+ @ignore_unicode_prefix
+ @since(2.1)
+ def registerJavaFunction(self, name, javaClassName, returnType=None):
+ """Register a java UDF so it can be used in SQL statements.
+
+ In addition to a name and the function itself, the return type can be optionally specified.
+ When the return type is not specified we would infer it via reflection.
+ :param name: name of the UDF
+ :param javaClassName: fully qualified name of java class
+ :param returnType: a :class:`pyspark.sql.types.DataType` object
+
+ >>> sqlContext.registerJavaFunction("javaStringLength",
+ ... "test.org.apache.spark.sql.JavaStringLength", IntegerType())
+ >>> sqlContext.sql("SELECT javaStringLength('test')").collect()
+ [Row(UDF(test)=4)]
+ >>> sqlContext.registerJavaFunction("javaStringLength2",
+ ... "test.org.apache.spark.sql.JavaStringLength")
+ >>> sqlContext.sql("SELECT javaStringLength2('test')").collect()
+ [Row(UDF(test)=4)]
+
+ """
+ jdt = None
+ if returnType is not None:
+ jdt = self.sparkSession._jsparkSession.parseDataType(returnType.json())
+ self.sparkSession._jsparkSession.udf().registerJava(name, javaClassName, jdt)
+
# TODO(andrew): delete this once we refactor things to take in SparkSession
def _inferSchema(self, rdd, samplingRatio=None):
"""
http://git-wip-us.apache.org/repos/asf/spark/blob/f00df40c/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
index e6f61b0..04f0cfc 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
@@ -59,7 +59,7 @@ object JavaTypeInference {
* @param typeToken Java type
* @return (SQL data type, nullable)
*/
- private def inferDataType(typeToken: TypeToken[_]): (DataType, Boolean) = {
+ private[sql] def inferDataType(typeToken: TypeToken[_]): (DataType, Boolean) = {
typeToken.getRawType match {
case c: Class[_] if c.isAnnotationPresent(classOf[SQLUserDefinedType]) =>
(c.getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance(), true)
http://git-wip-us.apache.org/repos/asf/spark/blob/f00df40c/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
index 617a147..0444ad1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
@@ -17,19 +17,25 @@
package org.apache.spark.sql
+import java.io.IOException
+import java.lang.reflect.{ParameterizedType, Type}
+
import scala.reflect.runtime.universe.TypeTag
import scala.util.Try
+import com.google.common.reflect.TypeToken
+
import org.apache.spark.annotation.InterfaceStability
import org.apache.spark.internal.Logging
import org.apache.spark.sql.api.java._
+import org.apache.spark.sql.catalyst.{JavaTypeInference, ScalaReflection}
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry
-import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.expressions.{Expression, ScalaUDF}
import org.apache.spark.sql.execution.aggregate.ScalaUDAF
import org.apache.spark.sql.execution.python.UserDefinedPythonFunction
import org.apache.spark.sql.expressions.{UserDefinedAggregateFunction, UserDefinedFunction}
-import org.apache.spark.sql.types.DataType
+import org.apache.spark.sql.types.{DataType, DataTypes}
+import org.apache.spark.util.Utils
/**
* Functions for registering user-defined functions. Use [[SQLContext.udf]] to access this.
@@ -414,6 +420,71 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
//////////////////////////////////////////////////////////////////////////////////////////////
/**
+ * Register a Java UDF class using reflection, for use from PySpark.
+ *
+ * @param name udf name
+ * @param className fully qualified class name of the UDF
+ * @param returnDataType return type of the UDF. If it is null, Spark will try to infer it
+ * via reflection.
+ */
+ private[sql] def registerJava(name: String, className: String, returnDataType: DataType): Unit = {
+
+ try {
+ val clazz = Utils.classForName(className)
+ val udfInterfaces = clazz.getGenericInterfaces
+ .filter(_.isInstanceOf[ParameterizedType])
+ .map(_.asInstanceOf[ParameterizedType])
+ .filter(e => e.getRawType.isInstanceOf[Class[_]] && e.getRawType.asInstanceOf[Class[_]].getCanonicalName.startsWith("org.apache.spark.sql.api.java.UDF"))
+ if (udfInterfaces.length == 0) {
+ throw new IOException(s"UDF class ${className} doesn't implement any UDF interface")
+ } else if (udfInterfaces.length > 1) {
+ throw new IOException(s"It is invalid to implement multiple UDF interfaces, UDF class ${className}")
+ } else {
+ try {
+ val udf = clazz.newInstance()
+ val udfReturnType = udfInterfaces(0).getActualTypeArguments.last
+ var returnType = returnDataType
+ if (returnType == null) {
+ returnType = JavaTypeInference.inferDataType(TypeToken.of(udfReturnType))._1
+ }
+
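+ // A UDFn interface declares n input type parameters plus one return type,
+ // so an implementation with n+1 actual type arguments maps to UDFn.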
+ udfInterfaces(0).getActualTypeArguments.length match {
+ case 2 => register(name, udf.asInstanceOf[UDF1[_, _]], returnType)
+ case 3 => register(name, udf.asInstanceOf[UDF2[_, _, _]], returnType)
+ case 4 => register(name, udf.asInstanceOf[UDF3[_, _, _, _]], returnType)
+ case 5 => register(name, udf.asInstanceOf[UDF4[_, _, _, _, _]], returnType)
+ case 6 => register(name, udf.asInstanceOf[UDF5[_, _, _, _, _, _]], returnType)
+ case 7 => register(name, udf.asInstanceOf[UDF6[_, _, _, _, _, _, _]], returnType)
+ case 8 => register(name, udf.asInstanceOf[UDF7[_, _, _, _, _, _, _, _]], returnType)
+ case 9 => register(name, udf.asInstanceOf[UDF8[_, _, _, _, _, _, _, _, _]], returnType)
+ case 10 => register(name, udf.asInstanceOf[UDF9[_, _, _, _, _, _, _, _, _, _]], returnType)
+ case 11 => register(name, udf.asInstanceOf[UDF10[_, _, _, _, _, _, _, _, _, _, _]], returnType)
+ case 12 => register(name, udf.asInstanceOf[UDF11[_, _, _, _, _, _, _, _, _, _, _, _]], returnType)
+ case 13 => register(name, udf.asInstanceOf[UDF12[_, _, _, _, _, _, _, _, _, _, _, _, _]], returnType)
+ case 14 => register(name, udf.asInstanceOf[UDF13[_, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType)
+ case 15 => register(name, udf.asInstanceOf[UDF14[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType)
+ case 16 => register(name, udf.asInstanceOf[UDF15[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType)
+ case 17 => register(name, udf.asInstanceOf[UDF16[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType)
+ case 18 => register(name, udf.asInstanceOf[UDF17[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType)
+ case 19 => register(name, udf.asInstanceOf[UDF18[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType)
+ case 20 => register(name, udf.asInstanceOf[UDF19[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType)
+ case 21 => register(name, udf.asInstanceOf[UDF20[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType)
+ case 22 => register(name, udf.asInstanceOf[UDF21[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType)
+ case 23 => register(name, udf.asInstanceOf[UDF22[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType)
+ case n => logError(s"UDF class with ${n} type arguments is not supported.")
+ }
+ } catch {
+ case e @ (_: InstantiationException | _: IllegalArgumentException) =>
+ logError(s"Can not instantiate class ${className}, please make sure it has public non argument constructor")
+ }
+ }
+ } catch {
+ case e: ClassNotFoundException => logError(s"Cannot load class ${className}; please make sure it is on the classpath")
+ }
+
+ }
+
+ /**
* Register a user-defined function with 1 argument.
* @since 1.3.0
*/
http://git-wip-us.apache.org/repos/asf/spark/blob/f00df40c/sql/core/src/test/java/test/org/apache/spark/sql/JavaStringLength.java
----------------------------------------------------------------------
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaStringLength.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaStringLength.java
new file mode 100644
index 0000000..b90224f
--- /dev/null
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaStringLength.java
@@ -0,0 +1,30 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package test.org.apache.spark.sql;
+
+import org.apache.spark.sql.api.java.UDF1;
+
+/**
+ * Used to test registering a Java UDF from PySpark.
+ */
+public class JavaStringLength implements UDF1<String, Integer> {
+ @Override
+ public Integer call(String str) throws Exception {
+ return new Integer(str.length());
+ }
+}
http://git-wip-us.apache.org/repos/asf/spark/blob/f00df40c/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java
----------------------------------------------------------------------
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java
index 2274912..8bf3278 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java
@@ -87,4 +87,25 @@ public class JavaUDFSuite implements Serializable {
Row result = spark.sql("SELECT stringLengthTest('test', 'test2')").head();
Assert.assertEquals(9, result.getInt(0));
}
+
+ public static class StringLengthTest implements UDF2<String, String, Integer> {
+ @Override
+ public Integer call(String str1, String str2) throws Exception {
+ return new Integer(str1.length() + str2.length());
+ }
+ }
+
+ @SuppressWarnings("unchecked")
+ @Test
+ public void udf3Test() {
+ spark.udf().registerJava("stringLengthTest", StringLengthTest.class.getName(),
+ DataTypes.IntegerType);
+ Row result = spark.sql("SELECT stringLengthTest('test', 'test2')").head();
+ Assert.assertEquals(9, result.getInt(0));
+
+ // returnType is not provided
+ spark.udf().registerJava("stringLengthTest2", StringLengthTest.class.getName(), null);
+ result = spark.sql("SELECT stringLengthTest('test', 'test2')").head();
+ Assert.assertEquals(9, result.getInt(0));
+ }
}