You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by rx...@apache.org on 2015/04/21 09:08:21 UTC

spark git commit: [SPARK-6949] [SQL] [PySpark] Support Date/Timestamp in Column expression

Repository: spark
Updated Branches:
  refs/heads/master 8136810df -> ab9128fb7


[SPARK-6949] [SQL] [PySpark] Support Date/Timestamp in Column expression

This PR enables auto_convert in JavaGateway so that we can register converters for given types, for example, date and datetime.

There are two bugs related to auto_convert (see [1] and [2]); we work around them in this PR.

[1]  https://github.com/bartdag/py4j/issues/160
[2] https://github.com/bartdag/py4j/issues/161

cc rxin JoshRosen

Author: Davies Liu <da...@databricks.com>

Closes #5570 from davies/py4j_date and squashes the following commits:

eb4fa53 [Davies Liu] fix tests in python 3
d17d634 [Davies Liu] rollback changes in mllib
2e7566d [Davies Liu] convert tuple into ArrayList
ceb3779 [Davies Liu] Update rdd.py
3c373f3 [Davies Liu] support date and datetime by auto_convert
cb094ff [Davies Liu] enable auto convert


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/ab9128fb
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/ab9128fb
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/ab9128fb

Branch: refs/heads/master
Commit: ab9128fb7ec7ca479dc91e7cc7c775e8d071eafa
Parents: 8136810
Author: Davies Liu <da...@databricks.com>
Authored: Tue Apr 21 00:08:18 2015 -0700
Committer: Reynold Xin <rx...@databricks.com>
Committed: Tue Apr 21 00:08:18 2015 -0700

----------------------------------------------------------------------
 python/pyspark/context.py           |  6 +-----
 python/pyspark/java_gateway.py      | 15 ++++++++++++++-
 python/pyspark/rdd.py               |  3 +++
 python/pyspark/sql/_types.py        | 27 +++++++++++++++++++++++++++
 python/pyspark/sql/context.py       | 13 ++++---------
 python/pyspark/sql/dataframe.py     | 18 ++++--------------
 python/pyspark/sql/tests.py         | 11 +++++++++++
 python/pyspark/streaming/context.py | 11 +++--------
 python/pyspark/streaming/kafka.py   |  7 ++-----
 python/pyspark/streaming/tests.py   |  6 +-----
 10 files changed, 70 insertions(+), 47 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/ab9128fb/python/pyspark/context.py
----------------------------------------------------------------------
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index 6a743ac..b006120 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -23,8 +23,6 @@ import sys
 from threading import Lock
 from tempfile import NamedTemporaryFile
 
-from py4j.java_collections import ListConverter
-
 from pyspark import accumulators
 from pyspark.accumulators import Accumulator
 from pyspark.broadcast import Broadcast
@@ -643,7 +641,6 @@ class SparkContext(object):
             rdds = [x._reserialize() for x in rdds]
         first = rdds[0]._jrdd
         rest = [x._jrdd for x in rdds[1:]]
-        rest = ListConverter().convert(rest, self._gateway._gateway_client)
         return RDD(self._jsc.union(first, rest), self, rdds[0]._jrdd_deserializer)
 
     def broadcast(self, value):
@@ -846,13 +843,12 @@ class SparkContext(object):
         """
         if partitions is None:
             partitions = range(rdd._jrdd.partitions().size())
-        javaPartitions = ListConverter().convert(partitions, self._gateway._gateway_client)
 
         # Implementation note: This is implemented as a mapPartitions followed
         # by runJob() in order to avoid having to pass a Python lambda into
         # SparkContext#runJob.
         mappedRDD = rdd.mapPartitions(partitionFunc)
-        port = self._jvm.PythonRDD.runJob(self._jsc.sc(), mappedRDD._jrdd, javaPartitions,
+        port = self._jvm.PythonRDD.runJob(self._jsc.sc(), mappedRDD._jrdd, partitions,
                                           allowLocal)
         return list(_load_from_socket(port, mappedRDD._jrdd_deserializer))
 

http://git-wip-us.apache.org/repos/asf/spark/blob/ab9128fb/python/pyspark/java_gateway.py
----------------------------------------------------------------------
diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py
index 45bc38f..3cee4ea 100644
--- a/python/pyspark/java_gateway.py
+++ b/python/pyspark/java_gateway.py
@@ -17,17 +17,30 @@
 
 import atexit
 import os
+import sys
 import select
 import signal
 import shlex
 import socket
 import platform
 from subprocess import Popen, PIPE
+
+if sys.version >= '3':
+    xrange = range
+
 from py4j.java_gateway import java_import, JavaGateway, GatewayClient
+from py4j.java_collections import ListConverter
 
 from pyspark.serializers import read_int
 
 
+# patching ListConverter, or it will convert bytearray into Java ArrayList
+def can_convert_list(self, obj):
+    return isinstance(obj, (list, tuple, xrange))
+
+ListConverter.can_convert = can_convert_list
+
+
 def launch_gateway():
     if "PYSPARK_GATEWAY_PORT" in os.environ:
         gateway_port = int(os.environ["PYSPARK_GATEWAY_PORT"])
@@ -92,7 +105,7 @@ def launch_gateway():
             atexit.register(killChild)
 
     # Connect to the gateway
-    gateway = JavaGateway(GatewayClient(port=gateway_port), auto_convert=False)
+    gateway = JavaGateway(GatewayClient(port=gateway_port), auto_convert=True)
 
     # Import the classes used by PySpark
     java_import(gateway.jvm, "org.apache.spark.SparkConf")

http://git-wip-us.apache.org/repos/asf/spark/blob/ab9128fb/python/pyspark/rdd.py
----------------------------------------------------------------------
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index d9cdbb6..d254deb 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -2267,6 +2267,9 @@ def _prepare_for_python_RDD(sc, command, obj=None):
         # The broadcast will have same life cycle as created PythonRDD
         broadcast = sc.broadcast(pickled_command)
         pickled_command = ser.dumps(broadcast)
+    # There is a bug in py4j.java_gateway.JavaClass with auto_convert
+    # https://github.com/bartdag/py4j/issues/161
+    # TODO: use auto_convert once py4j fix the bug
     broadcast_vars = ListConverter().convert(
         [x._jbroadcast for x in sc._pickled_broadcast_vars],
         sc._gateway._gateway_client)

http://git-wip-us.apache.org/repos/asf/spark/blob/ab9128fb/python/pyspark/sql/_types.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/_types.py b/python/pyspark/sql/_types.py
index 110d115..95fb91a 100644
--- a/python/pyspark/sql/_types.py
+++ b/python/pyspark/sql/_types.py
@@ -17,6 +17,7 @@
 
 import sys
 import decimal
+import time
 import datetime
 import keyword
 import warnings
@@ -30,6 +31,9 @@ if sys.version >= "3":
     long = int
     unicode = str
 
+from py4j.protocol import register_input_converter
+from py4j.java_gateway import JavaClass
+
 __all__ = [
     "DataType", "NullType", "StringType", "BinaryType", "BooleanType", "DateType",
     "TimestampType", "DecimalType", "DoubleType", "FloatType", "ByteType", "IntegerType",
@@ -1237,6 +1241,29 @@ class Row(tuple):
             return "<Row(%s)>" % ", ".join(self)
 
 
+class DateConverter(object):
+    def can_convert(self, obj):
+        return isinstance(obj, datetime.date)
+
+    def convert(self, obj, gateway_client):
+        Date = JavaClass("java.sql.Date", gateway_client)
+        return Date.valueOf(obj.strftime("%Y-%m-%d"))
+
+
+class DatetimeConverter(object):
+    def can_convert(self, obj):
+        return isinstance(obj, datetime.datetime)
+
+    def convert(self, obj, gateway_client):
+        Timestamp = JavaClass("java.sql.Timestamp", gateway_client)
+        return Timestamp(int(time.mktime(obj.timetuple())) * 1000 + obj.microsecond // 1000)
+
+
+# datetime is a subclass of date, we should register DatetimeConverter first
+register_input_converter(DatetimeConverter())
+register_input_converter(DateConverter())
+
+
 def _test():
     import doctest
     from pyspark.context import SparkContext

http://git-wip-us.apache.org/repos/asf/spark/blob/ab9128fb/python/pyspark/sql/context.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py
index acf3c11..f6f107c 100644
--- a/python/pyspark/sql/context.py
+++ b/python/pyspark/sql/context.py
@@ -25,7 +25,6 @@ else:
     from itertools import imap as map
 
 from py4j.protocol import Py4JError
-from py4j.java_collections import MapConverter
 
 from pyspark.rdd import RDD, _prepare_for_python_RDD, ignore_unicode_prefix
 from pyspark.serializers import AutoBatchedSerializer, PickleSerializer
@@ -442,15 +441,13 @@ class SQLContext(object):
         if source is None:
             source = self.getConf("spark.sql.sources.default",
                                   "org.apache.spark.sql.parquet")
-        joptions = MapConverter().convert(options,
-                                          self._sc._gateway._gateway_client)
         if schema is None:
-            df = self._ssql_ctx.load(source, joptions)
+            df = self._ssql_ctx.load(source, options)
         else:
             if not isinstance(schema, StructType):
                 raise TypeError("schema should be StructType")
             scala_datatype = self._ssql_ctx.parseDataType(schema.json())
-            df = self._ssql_ctx.load(source, scala_datatype, joptions)
+            df = self._ssql_ctx.load(source, scala_datatype, options)
         return DataFrame(df, self)
 
     def createExternalTable(self, tableName, path=None, source=None,
@@ -471,16 +468,14 @@ class SQLContext(object):
         if source is None:
             source = self.getConf("spark.sql.sources.default",
                                   "org.apache.spark.sql.parquet")
-        joptions = MapConverter().convert(options,
-                                          self._sc._gateway._gateway_client)
         if schema is None:
-            df = self._ssql_ctx.createExternalTable(tableName, source, joptions)
+            df = self._ssql_ctx.createExternalTable(tableName, source, options)
         else:
             if not isinstance(schema, StructType):
                 raise TypeError("schema should be StructType")
             scala_datatype = self._ssql_ctx.parseDataType(schema.json())
             df = self._ssql_ctx.createExternalTable(tableName, source, scala_datatype,
-                                                    joptions)
+                                                    options)
         return DataFrame(df, self)
 
     @ignore_unicode_prefix

http://git-wip-us.apache.org/repos/asf/spark/blob/ab9128fb/python/pyspark/sql/dataframe.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 75c181c..ca9bf8e 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -25,8 +25,6 @@ if sys.version >= '3':
 else:
     from itertools import imap as map
 
-from py4j.java_collections import ListConverter, MapConverter
-
 from pyspark.context import SparkContext
 from pyspark.rdd import RDD, _load_from_socket, ignore_unicode_prefix
 from pyspark.serializers import BatchedSerializer, PickleSerializer, UTF8Deserializer
@@ -186,9 +184,7 @@ class DataFrame(object):
             source = self.sql_ctx.getConf("spark.sql.sources.default",
                                           "org.apache.spark.sql.parquet")
         jmode = self._java_save_mode(mode)
-        joptions = MapConverter().convert(options,
-                                          self.sql_ctx._sc._gateway._gateway_client)
-        self._jdf.saveAsTable(tableName, source, jmode, joptions)
+        self._jdf.saveAsTable(tableName, source, jmode, options)
 
     def save(self, path=None, source=None, mode="error", **options):
         """Saves the contents of the :class:`DataFrame` to a data source.
@@ -211,9 +207,7 @@ class DataFrame(object):
             source = self.sql_ctx.getConf("spark.sql.sources.default",
                                           "org.apache.spark.sql.parquet")
         jmode = self._java_save_mode(mode)
-        joptions = MapConverter().convert(options,
-                                          self._sc._gateway._gateway_client)
-        self._jdf.save(source, jmode, joptions)
+        self._jdf.save(source, jmode, options)
 
     @property
     def schema(self):
@@ -819,7 +813,6 @@ class DataFrame(object):
             value = float(value)
 
         if isinstance(value, dict):
-            value = MapConverter().convert(value, self.sql_ctx._sc._gateway._gateway_client)
             return DataFrame(self._jdf.na().fill(value), self.sql_ctx)
         elif subset is None:
             return DataFrame(self._jdf.na().fill(value), self.sql_ctx)
@@ -932,9 +925,7 @@ class GroupedData(object):
         """
         assert exprs, "exprs should not be empty"
         if len(exprs) == 1 and isinstance(exprs[0], dict):
-            jmap = MapConverter().convert(exprs[0],
-                                          self.sql_ctx._sc._gateway._gateway_client)
-            jdf = self._jdf.agg(jmap)
+            jdf = self._jdf.agg(exprs[0])
         else:
             # Columns
             assert all(isinstance(c, Column) for c in exprs), "all exprs should be Column"
@@ -1040,8 +1031,7 @@ def _to_seq(sc, cols, converter=None):
     """
     if converter:
         cols = [converter(c) for c in cols]
-    jcols = ListConverter().convert(cols, sc._gateway._gateway_client)
-    return sc._jvm.PythonUtils.toSeq(jcols)
+    return sc._jvm.PythonUtils.toSeq(cols)
 
 
 def _unary_op(name, doc="unary operator"):

http://git-wip-us.apache.org/repos/asf/spark/blob/ab9128fb/python/pyspark/sql/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index aa3aa1d..23e8428 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -26,6 +26,7 @@ import shutil
 import tempfile
 import pickle
 import functools
+import datetime
 
 import py4j
 
@@ -464,6 +465,16 @@ class SQLTests(ReusedPySparkTestCase):
         self.assertEqual(_infer_type(2**61), LongType())
         self.assertEqual(_infer_type(2**71), LongType())
 
+    def test_filter_with_datetime(self):
+        time = datetime.datetime(2015, 4, 17, 23, 1, 2, 3000)
+        date = time.date()
+        row = Row(date=date, time=time)
+        df = self.sqlCtx.createDataFrame([row])
+        self.assertEqual(1, df.filter(df.date == date).count())
+        self.assertEqual(1, df.filter(df.time == time).count())
+        self.assertEqual(0, df.filter(df.date > date).count())
+        self.assertEqual(0, df.filter(df.time > time).count())
+
     def test_dropna(self):
         schema = StructType([
             StructField("name", StringType(), True),

http://git-wip-us.apache.org/repos/asf/spark/blob/ab9128fb/python/pyspark/streaming/context.py
----------------------------------------------------------------------
diff --git a/python/pyspark/streaming/context.py b/python/pyspark/streaming/context.py
index 4590c58..ac5ba69 100644
--- a/python/pyspark/streaming/context.py
+++ b/python/pyspark/streaming/context.py
@@ -20,7 +20,6 @@ from __future__ import print_function
 import os
 import sys
 
-from py4j.java_collections import ListConverter
 from py4j.java_gateway import java_import, JavaObject
 
 from pyspark import RDD, SparkConf
@@ -305,9 +304,7 @@ class StreamingContext(object):
             rdds = [self._sc.parallelize(input) for input in rdds]
         self._check_serializers(rdds)
 
-        jrdds = ListConverter().convert([r._jrdd for r in rdds],
-                                        SparkContext._gateway._gateway_client)
-        queue = self._jvm.PythonDStream.toRDDQueue(jrdds)
+        queue = self._jvm.PythonDStream.toRDDQueue([r._jrdd for r in rdds])
         if default:
             default = default._reserialize(rdds[0]._jrdd_deserializer)
             jdstream = self._jssc.queueStream(queue, oneAtATime, default._jrdd)
@@ -322,8 +319,7 @@ class StreamingContext(object):
         the transform function parameter will be the same as the order
         of corresponding DStreams in the list.
         """
-        jdstreams = ListConverter().convert([d._jdstream for d in dstreams],
-                                            SparkContext._gateway._gateway_client)
+        jdstreams = [d._jdstream for d in dstreams]
         # change the final serializer to sc.serializer
         func = TransformFunction(self._sc,
                                  lambda t, *rdds: transformFunc(rdds).map(lambda x: x),
@@ -346,6 +342,5 @@ class StreamingContext(object):
         if len(set(s._slideDuration for s in dstreams)) > 1:
             raise ValueError("All DStreams should have same slide duration")
         first = dstreams[0]
-        jrest = ListConverter().convert([d._jdstream for d in dstreams[1:]],
-                                        SparkContext._gateway._gateway_client)
+        jrest = [d._jdstream for d in dstreams[1:]]
         return DStream(self._jssc.union(first._jdstream, jrest), self, first._jrdd_deserializer)

http://git-wip-us.apache.org/repos/asf/spark/blob/ab9128fb/python/pyspark/streaming/kafka.py
----------------------------------------------------------------------
diff --git a/python/pyspark/streaming/kafka.py b/python/pyspark/streaming/kafka.py
index 7a7b6e1..8d610d6 100644
--- a/python/pyspark/streaming/kafka.py
+++ b/python/pyspark/streaming/kafka.py
@@ -15,8 +15,7 @@
 # limitations under the License.
 #
 
-from py4j.java_collections import MapConverter
-from py4j.java_gateway import java_import, Py4JError, Py4JJavaError
+from py4j.java_gateway import Py4JJavaError
 
 from pyspark.storagelevel import StorageLevel
 from pyspark.serializers import PairDeserializer, NoOpSerializer
@@ -57,8 +56,6 @@ class KafkaUtils(object):
         })
         if not isinstance(topics, dict):
             raise TypeError("topics should be dict")
-        jtopics = MapConverter().convert(topics, ssc.sparkContext._gateway._gateway_client)
-        jparam = MapConverter().convert(kafkaParams, ssc.sparkContext._gateway._gateway_client)
         jlevel = ssc._sc._getJavaStorageLevel(storageLevel)
 
         try:
@@ -66,7 +63,7 @@ class KafkaUtils(object):
             helperClass = ssc._jvm.java.lang.Thread.currentThread().getContextClassLoader()\
                 .loadClass("org.apache.spark.streaming.kafka.KafkaUtilsPythonHelper")
             helper = helperClass.newInstance()
-            jstream = helper.createStream(ssc._jssc, jparam, jtopics, jlevel)
+            jstream = helper.createStream(ssc._jssc, kafkaParams, topics, jlevel)
         except Py4JJavaError as e:
             # TODO: use --jar once it also work on driver
             if 'ClassNotFoundException' in str(e.java_exception):

http://git-wip-us.apache.org/repos/asf/spark/blob/ab9128fb/python/pyspark/streaming/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py
index 06d2215..33f958a 100644
--- a/python/pyspark/streaming/tests.py
+++ b/python/pyspark/streaming/tests.py
@@ -24,8 +24,6 @@ import tempfile
 import struct
 from functools import reduce
 
-from py4j.java_collections import MapConverter
-
 from pyspark.context import SparkConf, SparkContext, RDD
 from pyspark.streaming.context import StreamingContext
 from pyspark.streaming.kafka import KafkaUtils
@@ -581,11 +579,9 @@ class KafkaStreamTests(PySparkStreamingTestCase):
         """Test the Python Kafka stream API."""
         topic = "topic1"
         sendData = {"a": 3, "b": 5, "c": 10}
-        jSendData = MapConverter().convert(sendData,
-                                           self.ssc.sparkContext._gateway._gateway_client)
 
         self._kafkaTestUtils.createTopic(topic)
-        self._kafkaTestUtils.sendMessages(topic, jSendData)
+        self._kafkaTestUtils.sendMessages(topic, sendData)
 
         stream = KafkaUtils.createStream(self.ssc, self._kafkaTestUtils.zkAddress(),
                                          "test-streaming-consumer", {topic: 1},


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org