Posted to commits@spark.apache.org by rx...@apache.org on 2015/02/05 00:55:12 UTC

spark git commit: [SPARK-5577] Python udf for DataFrame

Repository: spark
Updated Branches:
  refs/heads/master e0490e271 -> dc101b0e4


[SPARK-5577] Python udf for DataFrame

Author: Davies Liu <da...@databricks.com>

Closes #4351 from davies/python_udf and squashes the following commits:

d250692 [Davies Liu] fix conflict
34234d4 [Davies Liu] Merge branch 'master' of github.com:apache/spark into python_udf
440f769 [Davies Liu] address comments
f0a3121 [Davies Liu] track life cycle of broadcast
f99b2e1 [Davies Liu] address comments
462b334 [Davies Liu] Merge branch 'master' of github.com:apache/spark into python_udf
7bccc3b [Davies Liu] python udf
58dee20 [Davies Liu] clean up


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/dc101b0e
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/dc101b0e
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/dc101b0e

Branch: refs/heads/master
Commit: dc101b0e4e23dffddbc2f70d14a19fae5d87a328
Parents: e0490e2
Author: Davies Liu <da...@databricks.com>
Authored: Wed Feb 4 15:55:09 2015 -0800
Committer: Reynold Xin <rx...@databricks.com>
Committed: Wed Feb 4 15:55:09 2015 -0800

----------------------------------------------------------------------
 python/pyspark/rdd.py                           |  38 ++--
 python/pyspark/sql.py                           | 195 +++++++++----------
 .../scala/org/apache/spark/sql/Column.scala     |  19 +-
 .../apache/spark/sql/UserDefinedFunction.scala  |  27 +++
 4 files changed, 157 insertions(+), 122 deletions(-)
----------------------------------------------------------------------
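
For readers skimming the diff, here is a short usage sketch of the Python UDF API this patch adds, assembled from the doctests in python/pyspark/sql.py below (it assumes the same test fixture used there: a DataFrame df with rows (2, u'Alice') and (5, u'Bob')):

    >>> from pyspark.sql import Dsl, IntegerType
    >>> slen = Dsl.udf(lambda s: len(s), IntegerType())   # wraps the lambda in a UserDefinedFunction
    >>> df.select(slen(df.name).alias('slen')).collect()  # calling the UDF yields a Column expression
    [Row(slen=5), Row(slen=3)]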


http://git-wip-us.apache.org/repos/asf/spark/blob/dc101b0e/python/pyspark/rdd.py
----------------------------------------------------------------------
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 2f8a0ed..6e029bf 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -2162,6 +2162,25 @@ class RDD(object):
                 yield row
 
 
+def _prepare_for_python_RDD(sc, command, obj=None):
+    # the serialized command will be compressed by broadcast
+    ser = CloudPickleSerializer()
+    pickled_command = ser.dumps(command)
+    if len(pickled_command) > (1 << 20):  # 1M
+        broadcast = sc.broadcast(pickled_command)
+        pickled_command = ser.dumps(broadcast)
+        # tracking the life cycle by obj
+        if obj is not None:
+            obj._broadcast = broadcast
+    broadcast_vars = ListConverter().convert(
+        [x._jbroadcast for x in sc._pickled_broadcast_vars],
+        sc._gateway._gateway_client)
+    sc._pickled_broadcast_vars.clear()
+    env = MapConverter().convert(sc.environment, sc._gateway._gateway_client)
+    includes = ListConverter().convert(sc._python_includes, sc._gateway._gateway_client)
+    return pickled_command, broadcast_vars, env, includes
+
+
 class PipelinedRDD(RDD):
 
     """
@@ -2228,25 +2247,12 @@ class PipelinedRDD(RDD):
 
         command = (self.func, profiler, self._prev_jrdd_deserializer,
                    self._jrdd_deserializer)
-        # the serialized command will be compressed by broadcast
-        ser = CloudPickleSerializer()
-        pickled_command = ser.dumps(command)
-        if len(pickled_command) > (1 << 20):  # 1M
-            self._broadcast = self.ctx.broadcast(pickled_command)
-            pickled_command = ser.dumps(self._broadcast)
-        broadcast_vars = ListConverter().convert(
-            [x._jbroadcast for x in self.ctx._pickled_broadcast_vars],
-            self.ctx._gateway._gateway_client)
-        self.ctx._pickled_broadcast_vars.clear()
-        env = MapConverter().convert(self.ctx.environment,
-                                     self.ctx._gateway._gateway_client)
-        includes = ListConverter().convert(self.ctx._python_includes,
-                                           self.ctx._gateway._gateway_client)
+        pickled_cmd, bvars, env, includes = _prepare_for_python_RDD(self.ctx, command, self)
         python_rdd = self.ctx._jvm.PythonRDD(self._prev_jrdd.rdd(),
-                                             bytearray(pickled_command),
+                                             bytearray(pickled_cmd),
                                              env, includes, self.preservesPartitioning,
                                              self.ctx.pythonExec,
-                                             broadcast_vars, self.ctx._javaAccumulator)
+                                             bvars, self.ctx._javaAccumulator)
         self._jrdd_val = python_rdd.asJavaRDD()
 
         if profiler:

http://git-wip-us.apache.org/repos/asf/spark/blob/dc101b0e/python/pyspark/sql.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py
index a266cde..5b56b36 100644
--- a/python/pyspark/sql.py
+++ b/python/pyspark/sql.py
@@ -51,7 +51,7 @@ from py4j.protocol import Py4JError
 from py4j.java_collections import ListConverter, MapConverter
 
 from pyspark.context import SparkContext
-from pyspark.rdd import RDD
+from pyspark.rdd import RDD, _prepare_for_python_RDD
 from pyspark.serializers import BatchedSerializer, AutoBatchedSerializer, PickleSerializer, \
     CloudPickleSerializer, UTF8Deserializer
 from pyspark.storagelevel import StorageLevel
@@ -1274,28 +1274,15 @@ class SQLContext(object):
         [Row(c0=4)]
         """
         func = lambda _, it: imap(lambda x: f(*x), it)
-        command = (func, None,
-                   AutoBatchedSerializer(PickleSerializer()),
-                   AutoBatchedSerializer(PickleSerializer()))
-        ser = CloudPickleSerializer()
-        pickled_command = ser.dumps(command)
-        if len(pickled_command) > (1 << 20):  # 1M
-            broadcast = self._sc.broadcast(pickled_command)
-            pickled_command = ser.dumps(broadcast)
-        broadcast_vars = ListConverter().convert(
-            [x._jbroadcast for x in self._sc._pickled_broadcast_vars],
-            self._sc._gateway._gateway_client)
-        self._sc._pickled_broadcast_vars.clear()
-        env = MapConverter().convert(self._sc.environment,
-                                     self._sc._gateway._gateway_client)
-        includes = ListConverter().convert(self._sc._python_includes,
-                                           self._sc._gateway._gateway_client)
+        ser = AutoBatchedSerializer(PickleSerializer())
+        command = (func, None, ser, ser)
+        pickled_cmd, bvars, env, includes = _prepare_for_python_RDD(self._sc, command, self)
         self._ssql_ctx.udf().registerPython(name,
-                                            bytearray(pickled_command),
+                                            bytearray(pickled_cmd),
                                             env,
                                             includes,
                                             self._sc.pythonExec,
-                                            broadcast_vars,
+                                            bvars,
                                             self._sc._javaAccumulator,
                                             returnType.json())
 
@@ -2077,9 +2064,9 @@ class DataFrame(object):
         """Return all column names and their data types as a list.
 
         >>> df.dtypes
-        [(u'age', 'IntegerType'), (u'name', 'StringType')]
+        [('age', 'integer'), ('name', 'string')]
         """
-        return [(f.name, str(f.dataType)) for f in self.schema().fields]
+        return [(str(f.name), f.dataType.jsonValue()) for f in self.schema().fields]
 
     @property
     def columns(self):
@@ -2194,7 +2181,7 @@ class DataFrame(object):
         [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')]
         >>> df.select('name', 'age').collect()
         [Row(name=u'Alice', age=2), Row(name=u'Bob', age=5)]
-        >>> df.select(df.name, (df.age + 10).As('age')).collect()
+        >>> df.select(df.name, (df.age + 10).alias('age')).collect()
         [Row(name=u'Alice', age=12), Row(name=u'Bob', age=15)]
         """
         if not cols:
@@ -2295,25 +2282,13 @@ class DataFrame(object):
         """
         return DataFrame(getattr(self._jdf, "except")(other._jdf), self.sql_ctx)
 
-    def sample(self, withReplacement, fraction, seed=None):
-        """ Return a new DataFrame by sampling a fraction of rows.
-
-        >>> df.sample(False, 0.5, 10).collect()
-        [Row(age=2, name=u'Alice')]
-        """
-        if seed is None:
-            jdf = self._jdf.sample(withReplacement, fraction)
-        else:
-            jdf = self._jdf.sample(withReplacement, fraction, seed)
-        return DataFrame(jdf, self.sql_ctx)
-
     def addColumn(self, colName, col):
         """ Return a new :class:`DataFrame` by adding a column.
 
         >>> df.addColumn('age2', df.age + 2).collect()
         [Row(age=2, name=u'Alice', age2=4), Row(age=5, name=u'Bob', age2=7)]
         """
-        return self.select('*', col.As(colName))
+        return self.select('*', col.alias(colName))
 
 
 # Having SchemaRDD for backward compatibility (for docs)
@@ -2408,28 +2383,6 @@ class GroupedDataFrame(object):
         group."""
 
 
-SCALA_METHOD_MAPPINGS = {
-    '=': '$eq',
-    '>': '$greater',
-    '<': '$less',
-    '+': '$plus',
-    '-': '$minus',
-    '*': '$times',
-    '/': '$div',
-    '!': '$bang',
-    '@': '$at',
-    '#': '$hash',
-    '%': '$percent',
-    '^': '$up',
-    '&': '$amp',
-    '~': '$tilde',
-    '?': '$qmark',
-    '|': '$bar',
-    '\\': '$bslash',
-    ':': '$colon',
-}
-
-
 def _create_column_from_literal(literal):
     sc = SparkContext._active_spark_context
     return sc._jvm.Dsl.lit(literal)
@@ -2448,23 +2401,18 @@ def _to_java_column(col):
     return jcol
 
 
-def _scalaMethod(name):
-    """ Translate operators into methodName in Scala
-
-    >>> _scalaMethod('+')
-    '$plus'
-    >>> _scalaMethod('>=')
-    '$greater$eq'
-    >>> _scalaMethod('cast')
-    'cast'
-    """
-    return ''.join(SCALA_METHOD_MAPPINGS.get(c, c) for c in name)
-
-
 def _unary_op(name, doc="unary operator"):
     """ Create a method for given unary operator """
     def _(self):
-        jc = getattr(self._jc, _scalaMethod(name))()
+        jc = getattr(self._jc, name)()
+        return Column(jc, self.sql_ctx)
+    _.__doc__ = doc
+    return _
+
+
+def _dsl_op(name, doc=''):
+    def _(self):
+        jc = getattr(self._sc._jvm.Dsl, name)(self._jc)
         return Column(jc, self.sql_ctx)
     _.__doc__ = doc
     return _
@@ -2475,7 +2423,7 @@ def _bin_op(name, doc="binary operator"):
     """
     def _(self, other):
         jc = other._jc if isinstance(other, Column) else other
-        njc = getattr(self._jc, _scalaMethod(name))(jc)
+        njc = getattr(self._jc, name)(jc)
         return Column(njc, self.sql_ctx)
     _.__doc__ = doc
     return _
@@ -2486,7 +2434,7 @@ def _reverse_op(name, doc="binary operator"):
     """
     def _(self, other):
         jother = _create_column_from_literal(other)
-        jc = getattr(jother, _scalaMethod(name))(self._jc)
+        jc = getattr(jother, name)(self._jc)
         return Column(jc, self.sql_ctx)
     _.__doc__ = doc
     return _
@@ -2513,34 +2461,33 @@ class Column(DataFrame):
         super(Column, self).__init__(jc, sql_ctx)
 
     # arithmetic operators
-    __neg__ = _unary_op("unary_-")
-    __add__ = _bin_op("+")
-    __sub__ = _bin_op("-")
-    __mul__ = _bin_op("*")
-    __div__ = _bin_op("/")
-    __mod__ = _bin_op("%")
-    __radd__ = _bin_op("+")
-    __rsub__ = _reverse_op("-")
-    __rmul__ = _bin_op("*")
-    __rdiv__ = _reverse_op("/")
-    __rmod__ = _reverse_op("%")
-    __abs__ = _unary_op("abs")
+    __neg__ = _dsl_op("negate")
+    __add__ = _bin_op("plus")
+    __sub__ = _bin_op("minus")
+    __mul__ = _bin_op("multiply")
+    __div__ = _bin_op("divide")
+    __mod__ = _bin_op("mod")
+    __radd__ = _bin_op("plus")
+    __rsub__ = _reverse_op("minus")
+    __rmul__ = _bin_op("multiply")
+    __rdiv__ = _reverse_op("divide")
+    __rmod__ = _reverse_op("mod")
 
     # logistic operators
-    __eq__ = _bin_op("===")
-    __ne__ = _bin_op("!==")
-    __lt__ = _bin_op("<")
-    __le__ = _bin_op("<=")
-    __ge__ = _bin_op(">=")
-    __gt__ = _bin_op(">")
+    __eq__ = _bin_op("equalTo")
+    __ne__ = _bin_op("notEqual")
+    __lt__ = _bin_op("lt")
+    __le__ = _bin_op("leq")
+    __ge__ = _bin_op("geq")
+    __gt__ = _bin_op("gt")
 
     # `and`, `or`, `not` cannot be overloaded in Python,
     # so use bitwise operators as boolean operators
-    __and__ = _bin_op('&&')
-    __or__ = _bin_op('||')
-    __invert__ = _unary_op('unary_!')
-    __rand__ = _bin_op("&&")
-    __ror__ = _bin_op("||")
+    __and__ = _bin_op('and')
+    __or__ = _bin_op('or')
+    __invert__ = _dsl_op('not')
+    __rand__ = _bin_op("and")
+    __ror__ = _bin_op("or")
 
     # container operators
     __contains__ = _bin_op("contains")
@@ -2582,24 +2529,20 @@ class Column(DataFrame):
     isNull = _unary_op("isNull", "True if the current expression is null.")
     isNotNull = _unary_op("isNotNull", "True if the current expression is not null.")
 
-    # `as` is keyword
     def alias(self, alias):
         """Return a alias for this column
 
-        >>> df.age.As("age2").collect()
-        [Row(age2=2), Row(age2=5)]
         >>> df.age.alias("age2").collect()
         [Row(age2=2), Row(age2=5)]
         """
         return Column(getattr(self._jc, "as")(alias), self.sql_ctx)
-    As = alias
 
     def cast(self, dataType):
         """ Convert the column into type `dataType`
 
-        >>> df.select(df.age.cast("string").As('ages')).collect()
+        >>> df.select(df.age.cast("string").alias('ages')).collect()
         [Row(ages=u'2'), Row(ages=u'5')]
-        >>> df.select(df.age.cast(StringType()).As('ages')).collect()
+        >>> df.select(df.age.cast(StringType()).alias('ages')).collect()
         [Row(ages=u'2'), Row(ages=u'5')]
         """
         if self.sql_ctx is None:
@@ -2626,6 +2569,40 @@ def _aggregate_func(name, doc=""):
     return staticmethod(_)
 
 
+class UserDefinedFunction(object):
+    def __init__(self, func, returnType):
+        self.func = func
+        self.returnType = returnType
+        self._broadcast = None
+        self._judf = self._create_judf()
+
+    def _create_judf(self):
+        f = self.func  # put it in closure `func`
+        func = lambda _, it: imap(lambda x: f(*x), it)
+        ser = AutoBatchedSerializer(PickleSerializer())
+        command = (func, None, ser, ser)
+        sc = SparkContext._active_spark_context
+        pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(sc, command, self)
+        ssql_ctx = sc._jvm.SQLContext(sc._jsc.sc())
+        jdt = ssql_ctx.parseDataType(self.returnType.json())
+        judf = sc._jvm.UserDefinedPythonFunction(f.__name__, bytearray(pickled_command), env,
+                                                 includes, sc.pythonExec, broadcast_vars,
+                                                 sc._javaAccumulator, jdt)
+        return judf
+
+    def __del__(self):
+        if self._broadcast is not None:
+            self._broadcast.unpersist()
+            self._broadcast = None
+
+    def __call__(self, *cols):
+        sc = SparkContext._active_spark_context
+        jcols = ListConverter().convert([_to_java_column(c) for c in cols],
+                                        sc._gateway._gateway_client)
+        jc = self._judf.apply(sc._jvm.PythonUtils.toSeq(jcols))
+        return Column(jc)
+
+
 class Dsl(object):
     """
     A collections of builtin aggregators
@@ -2659,7 +2636,7 @@ class Dsl(object):
         """ Return a new Column for distinct count of (col, *cols)
 
         >>> from pyspark.sql import Dsl
-        >>> df.agg(Dsl.countDistinct(df.age, df.name).As('c')).collect()
+        >>> df.agg(Dsl.countDistinct(df.age, df.name).alias('c')).collect()
         [Row(c=2)]
         """
         sc = SparkContext._active_spark_context
@@ -2674,7 +2651,7 @@ class Dsl(object):
         """ Return a new Column for approxiate distinct count of (col, *cols)
 
         >>> from pyspark.sql import Dsl
-        >>> df.agg(Dsl.approxCountDistinct(df.age).As('c')).collect()
+        >>> df.agg(Dsl.approxCountDistinct(df.age).alias('c')).collect()
         [Row(c=2)]
         """
         sc = SparkContext._active_spark_context
@@ -2684,6 +2661,16 @@ class Dsl(object):
             jc = sc._jvm.Dsl.approxCountDistinct(_to_java_column(col), rsd)
         return Column(jc)
 
+    @staticmethod
+    def udf(f, returnType=StringType()):
+        """Create a user defined function (UDF)
+
+        >>> slen = Dsl.udf(lambda s: len(s), IntegerType())
+        >>> df.select(slen(df.name).alias('slen')).collect()
+        [Row(slen=5), Row(slen=3)]
+        """
+        return UserDefinedFunction(f, returnType)
+
 
 def _test():
     import doctest

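A note on the "track life cycle of broadcast" squash commit above: when a pickled command is larger than 1 MB, _prepare_for_python_RDD ships it through a broadcast variable and pins the handle on the owning object (obj._broadcast); UserDefinedFunction.__del__ then unpersists it when the UDF is garbage collected. A minimal, hypothetical sketch of that ownership pattern (names here are illustrative, not part of the patch):

    class _CommandOwner(object):
        """Illustrative only: owns the broadcast of a large pickled command."""

        def __init__(self, sc, pickled_command):
            self._broadcast = None
            if len(pickled_command) > (1 << 20):      # > 1 MB: do not ship inline
                self._broadcast = sc.broadcast(pickled_command)

        def __del__(self):
            # Release executor-side copies once the owner itself goes away.
            if self._broadcast is not None:
                self._broadcast.unpersist()
                self._broadcast = None
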
http://git-wip-us.apache.org/repos/asf/spark/blob/dc101b0e/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index ddce77d..4c2aead 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -128,7 +128,6 @@ trait Column extends DataFrame {
    */
   def unary_! : Column = exprToColumn(Not(expr))
 
-
   /**
    * Equality test.
    * {{{
@@ -166,7 +165,7 @@ trait Column extends DataFrame {
    *
    *   // Java:
    *   import static org.apache.spark.sql.Dsl.*;
-   *   df.filter( not(col("colA").equalTo(col("colB"))) );
+   *   df.filter( col("colA").notEqual(col("colB")) );
    * }}}
    */
   def !== (other: Any): Column = constructColumn(other) { o =>
@@ -174,6 +173,22 @@ trait Column extends DataFrame {
   }
 
   /**
+   * Inequality test.
+   * {{{
+   *   // Scala:
+   *   df.select( df("colA") !== df("colB") )
+   *   df.select( !(df("colA") === df("colB")) )
+   *
+   *   // Java:
+   *   import static org.apache.spark.sql.Dsl.*;
+   *   df.filter( col("colA").notEqual(col("colB")) );
+   * }}}
+   */
+  def notEqual(other: Any): Column = constructColumn(other) { o =>
+    Not(EqualTo(expr, o.expr))
+  }
+
+  /**
    * Greater than.
    * {{{
    *   // Scala: The following selects people older than 21.

http://git-wip-us.apache.org/repos/asf/spark/blob/dc101b0e/sql/core/src/main/scala/org/apache/spark/sql/UserDefinedFunction.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UserDefinedFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/UserDefinedFunction.scala
index 8d7c2a1..c60d407 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/UserDefinedFunction.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/UserDefinedFunction.scala
@@ -17,7 +17,13 @@
 
 package org.apache.spark.sql
 
+import java.util.{List => JList, Map => JMap}
+
+import org.apache.spark.Accumulator
+import org.apache.spark.api.python.PythonBroadcast
+import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.sql.catalyst.expressions.ScalaUdf
+import org.apache.spark.sql.execution.PythonUDF
 import org.apache.spark.sql.types.DataType
 
 /**
@@ -37,3 +43,24 @@ case class UserDefinedFunction(f: AnyRef, dataType: DataType) {
     Column(ScalaUdf(f, dataType, exprs.map(_.expr)))
   }
 }
+
+/**
+ * A user-defined Python function. To create one, use the `pythonUDF` functions in [[Dsl]].
+ * This is used by Python API.
+ */
+private[sql] case class UserDefinedPythonFunction(
+    name: String,
+    command: Array[Byte],
+    envVars: JMap[String, String],
+    pythonIncludes: JList[String],
+    pythonExec: String,
+    broadcastVars: JList[Broadcast[PythonBroadcast]],
+    accumulator: Accumulator[JList[Array[Byte]]],
+    dataType: DataType) {
+
+  def apply(exprs: Column*): Column = {
+    val udf = PythonUDF(name, command, envVars, pythonIncludes, pythonExec, broadcastVars,
+      accumulator, dataType, exprs.map(_.expr))
+    Column(udf)
+  }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org