You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by ma...@apache.org on 2016/10/14 22:50:39 UTC

spark git commit: [SPARK-11775][PYSPARK][SQL] Allow PySpark to register Java UDF

Repository: spark
Updated Branches:
  refs/heads/master 5aeb7384c -> f00df40cf


[SPARK-11775][PYSPARK][SQL] Allow PySpark to register Java UDF

Currently PySpark can only call built-in Java UDFs; it cannot call custom Java UDFs. It would be better to allow that. There are two benefits:
* Leverage the power of the rich set of third-party Java libraries.
* Improve performance: if we use a Python UDF, Python daemons are started on the workers, which hurts performance.

Author: Jeff Zhang <zj...@apache.org>

Closes #9766 from zjffdu/SPARK-11775.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/f00df40c
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/f00df40c
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/f00df40c

Branch: refs/heads/master
Commit: f00df40cfefef0f3fc73f16ada1006e4dcfa5a39
Parents: 5aeb738
Author: Jeff Zhang <zj...@apache.org>
Authored: Fri Oct 14 15:50:35 2016 -0700
Committer: Michael Armbrust <mi...@databricks.com>
Committed: Fri Oct 14 15:50:35 2016 -0700

----------------------------------------------------------------------
 python/pyspark/sql/context.py                   | 28 +++++++-
 .../spark/sql/catalyst/JavaTypeInference.scala  |  2 +-
 .../org/apache/spark/sql/UDFRegistration.scala  | 75 +++++++++++++++++++-
 .../org/apache/spark/sql/JavaStringLength.java  | 30 ++++++++
 .../test/org/apache/spark/sql/JavaUDFSuite.java | 21 ++++++
 5 files changed, 152 insertions(+), 4 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/f00df40c/python/pyspark/sql/context.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py
index 8264dcf..de4c335 100644
--- a/python/pyspark/sql/context.py
+++ b/python/pyspark/sql/context.py
@@ -28,7 +28,7 @@ from pyspark.sql.session import _monkey_patch_RDD, SparkSession
 from pyspark.sql.dataframe import DataFrame
 from pyspark.sql.readwriter import DataFrameReader
 from pyspark.sql.streaming import DataStreamReader
-from pyspark.sql.types import Row, StringType
+from pyspark.sql.types import IntegerType, Row, StringType
 from pyspark.sql.utils import install_exception_handler
 
 __all__ = ["SQLContext", "HiveContext", "UDFRegistration"]
@@ -202,6 +202,32 @@ class SQLContext(object):
         """
         self.sparkSession.catalog.registerFunction(name, f, returnType)
 
+    @ignore_unicode_prefix
+    @since(2.1)
+    def registerJavaFunction(self, name, javaClassName, returnType=None):
+        """Register a java UDF so it can be used in SQL statements.
+
+        In addition to a name and the function itself, the return type can be optionally specified.
+        When the return type is not specified we would infer it via reflection.
+        :param name:  name of the UDF
+        :param javaClassName: fully qualified name of java class
+        :param returnType: a :class:`pyspark.sql.types.DataType` object
+
+        >>> sqlContext.registerJavaFunction("javaStringLength",
+        ...   "test.org.apache.spark.sql.JavaStringLength", IntegerType())
+        >>> sqlContext.sql("SELECT javaStringLength('test')").collect()
+        [Row(UDF(test)=4)]
+        >>> sqlContext.registerJavaFunction("javaStringLength2",
+        ...   "test.org.apache.spark.sql.JavaStringLength")
+        >>> sqlContext.sql("SELECT javaStringLength2('test')").collect()
+        [Row(UDF(test)=4)]
+
+        """
+        jdt = None
+        if returnType is not None:
+            jdt = self.sparkSession._jsparkSession.parseDataType(returnType.json())
+        self.sparkSession._jsparkSession.udf().registerJava(name, javaClassName, jdt)
+
     # TODO(andrew): delete this once we refactor things to take in SparkSession
     def _inferSchema(self, rdd, samplingRatio=None):
         """

http://git-wip-us.apache.org/repos/asf/spark/blob/f00df40c/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
index e6f61b0..04f0cfc 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
@@ -59,7 +59,7 @@ object JavaTypeInference {
    * @param typeToken Java type
    * @return (SQL data type, nullable)
    */
-  private def inferDataType(typeToken: TypeToken[_]): (DataType, Boolean) = {
+  private[sql] def inferDataType(typeToken: TypeToken[_]): (DataType, Boolean) = {
     typeToken.getRawType match {
       case c: Class[_] if c.isAnnotationPresent(classOf[SQLUserDefinedType]) =>
         (c.getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance(), true)

http://git-wip-us.apache.org/repos/asf/spark/blob/f00df40c/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
index 617a147..0444ad1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
@@ -17,19 +17,25 @@
 
 package org.apache.spark.sql
 
+import java.io.IOException
+import java.lang.reflect.{ParameterizedType, Type}
+
 import scala.reflect.runtime.universe.TypeTag
 import scala.util.Try
 
+import com.google.common.reflect.TypeToken
+
 import org.apache.spark.annotation.InterfaceStability
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.api.java._
+import org.apache.spark.sql.catalyst.{JavaTypeInference, ScalaReflection}
 import org.apache.spark.sql.catalyst.analysis.FunctionRegistry
-import org.apache.spark.sql.catalyst.ScalaReflection
 import org.apache.spark.sql.catalyst.expressions.{Expression, ScalaUDF}
 import org.apache.spark.sql.execution.aggregate.ScalaUDAF
 import org.apache.spark.sql.execution.python.UserDefinedPythonFunction
 import org.apache.spark.sql.expressions.{UserDefinedAggregateFunction, UserDefinedFunction}
-import org.apache.spark.sql.types.DataType
+import org.apache.spark.sql.types.{DataType, DataTypes}
+import org.apache.spark.util.Utils
 
 /**
  * Functions for registering user-defined functions. Use [[SQLContext.udf]] to access this.
@@ -414,6 +420,71 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
   //////////////////////////////////////////////////////////////////////////////////////////////
 
   /**
+   * Register a Java UDF class using reflection, for use from pyspark
+   *
+   * @param name   udf name
+   * @param className   fully qualified class name of udf
+   * @param returnDataType  return type of udf. If it is null, spark would try to infer
+   *                        via reflection.
+   */
+  private[sql] def registerJava(name: String, className: String, returnDataType: DataType): Unit = {
+
+    try {
+      val clazz = Utils.classForName(className)
+      val udfInterfaces = clazz.getGenericInterfaces
+        .filter(_.isInstanceOf[ParameterizedType])
+        .map(_.asInstanceOf[ParameterizedType])
+        .filter(e => e.getRawType.isInstanceOf[Class[_]] && e.getRawType.asInstanceOf[Class[_]].getCanonicalName.startsWith("org.apache.spark.sql.api.java.UDF"))
+      if (udfInterfaces.length == 0) {
+        throw new IOException(s"UDF class ${className} doesn't implement any UDF interface")
+      } else if (udfInterfaces.length > 1) {
+        throw new IOException(s"It is invalid to implement multiple UDF interfaces, UDF class ${className}")
+      } else {
+        try {
+          val udf = clazz.newInstance()
+          val udfReturnType = udfInterfaces(0).getActualTypeArguments.last
+          var returnType = returnDataType
+          if (returnType == null) {
+            returnType = JavaTypeInference.inferDataType(TypeToken.of(udfReturnType))._1
+          }
+
+          udfInterfaces(0).getActualTypeArguments.length match {
+            case 2 => register(name, udf.asInstanceOf[UDF1[_, _]], returnType)
+            case 3 => register(name, udf.asInstanceOf[UDF2[_, _, _]], returnType)
+            case 4 => register(name, udf.asInstanceOf[UDF3[_, _, _, _]], returnType)
+            case 5 => register(name, udf.asInstanceOf[UDF4[_, _, _, _, _]], returnType)
+            case 6 => register(name, udf.asInstanceOf[UDF5[_, _, _, _, _, _]], returnType)
+            case 7 => register(name, udf.asInstanceOf[UDF6[_, _, _, _, _, _, _]], returnType)
+            case 8 => register(name, udf.asInstanceOf[UDF7[_, _, _, _, _, _, _, _]], returnType)
+            case 9 => register(name, udf.asInstanceOf[UDF8[_, _, _, _, _, _, _, _, _]], returnType)
+            case 10 => register(name, udf.asInstanceOf[UDF9[_, _, _, _, _, _, _, _, _, _]], returnType)
+            case 11 => register(name, udf.asInstanceOf[UDF10[_, _, _, _, _, _, _, _, _, _, _]], returnType)
+            case 12 => register(name, udf.asInstanceOf[UDF11[_, _, _, _, _, _, _, _, _, _, _, _]], returnType)
+            case 13 => register(name, udf.asInstanceOf[UDF12[_, _, _, _, _, _, _, _, _, _, _, _, _]], returnType)
+            case 14 => register(name, udf.asInstanceOf[UDF13[_, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType)
+            case 15 => register(name, udf.asInstanceOf[UDF14[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType)
+            case 16 => register(name, udf.asInstanceOf[UDF15[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType)
+            case 17 => register(name, udf.asInstanceOf[UDF16[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType)
+            case 18 => register(name, udf.asInstanceOf[UDF17[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType)
+            case 19 => register(name, udf.asInstanceOf[UDF18[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType)
+            case 20 => register(name, udf.asInstanceOf[UDF19[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType)
+            case 21 => register(name, udf.asInstanceOf[UDF20[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType)
+            case 22 => register(name, udf.asInstanceOf[UDF21[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType)
+            case 23 => register(name, udf.asInstanceOf[UDF22[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType)
+            case n => logError(s"UDF class with ${n} type arguments is not supported ")
+          }
+        } catch {
+          case e @ (_: InstantiationException | _: IllegalArgumentException) =>
+            logError(s"Can not instantiate class ${className}, please make sure it has public non argument constructor")
+        }
+      }
+    } catch {
+      case e: ClassNotFoundException => logError(s"Can not load class ${className}, please make sure it is on the classpath")
+    }
+
+  }
+
+  /**
    * Register a user-defined function with 1 arguments.
    * @since 1.3.0
    */

http://git-wip-us.apache.org/repos/asf/spark/blob/f00df40c/sql/core/src/test/java/test/org/apache/spark/sql/JavaStringLength.java
----------------------------------------------------------------------
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaStringLength.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaStringLength.java
new file mode 100644
index 0000000..b90224f
--- /dev/null
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaStringLength.java
@@ -0,0 +1,30 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package test.org.apache.spark.sql;
+
+import org.apache.spark.sql.api.java.UDF1;
+
+/**
+ * It is used for register Java UDF from PySpark
+ */
+public class JavaStringLength implements UDF1<String, Integer> {
+  @Override
+  public Integer call(String str) throws Exception {
+    return new Integer(str.length());
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/f00df40c/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java
----------------------------------------------------------------------
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java
index 2274912..8bf3278 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java
@@ -87,4 +87,25 @@ public class JavaUDFSuite implements Serializable {
     Row result = spark.sql("SELECT stringLengthTest('test', 'test2')").head();
     Assert.assertEquals(9, result.getInt(0));
   }
+
+  public static class StringLengthTest implements UDF2<String, String, Integer> {
+    @Override
+    public Integer call(String str1, String str2) throws Exception {
+      return new Integer(str1.length() + str2.length());
+    }
+  }
+
+  @SuppressWarnings("unchecked")
+  @Test
+  public void udf3Test() {
+    spark.udf().registerJava("stringLengthTest", StringLengthTest.class.getName(),
+        DataTypes.IntegerType);
+    Row result = spark.sql("SELECT stringLengthTest('test', 'test2')").head();
+    Assert.assertEquals(9, result.getInt(0));
+
+    // returnType is not provided
+    spark.udf().registerJava("stringLengthTest2", StringLengthTest.class.getName(), null);
+    result = spark.sql("SELECT stringLengthTest('test', 'test2')").head();
+    Assert.assertEquals(9, result.getInt(0));
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org