Posted to commits@spark.apache.org by ma...@apache.org on 2014/09/20 00:33:46 UTC

git commit: [SPARK-3592] [SQL] [PySpark] support applySchema to RDD of Row

Repository: spark
Updated Branches:
  refs/heads/master 5522151eb -> a95ad99e3


[SPARK-3592] [SQL] [PySpark] support applySchema to RDD of Row

Fix the issue when applying applySchema() to an RDD of Row objects.

Also add a type mapping for BinaryType, so that Python bytearray values are supported.
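
For illustration, a minimal sketch of what the fix enables, assuming a
running SparkContext (sc) and SQLContext (sqlCtx); the schema here is
hypothetical and mirrors the test added below:

    from pyspark.sql import Row, StructType, StructField, IntegerType

    # An RDD whose elements are Row objects rather than plain tuples
    rdd = sc.parallelize(range(10)).map(lambda x: Row(a=x))

    # Before this patch, applySchema() failed on such an RDD
    schema = StructType([StructField("a", IntegerType(), False)])
    srdd = sqlCtx.applySchema(rdd, schema)
    assert srdd.count() == 10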

Author: Davies Liu <da...@gmail.com>

Closes #2448 from davies/row and squashes the following commits:

dd220cf [Davies Liu] fix test
3f3f188 [Davies Liu] add more test
f559746 [Davies Liu] add tests, fix serialization
9688fd2 [Davies Liu] support applySchema to RDD of Row


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/a95ad99e
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/a95ad99e
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/a95ad99e

Branch: refs/heads/master
Commit: a95ad99e31c2d5980a3b8cd8e36ff968b1e6b201
Parents: 5522151
Author: Davies Liu <da...@gmail.com>
Authored: Fri Sep 19 15:33:42 2014 -0700
Committer: Michael Armbrust <mi...@databricks.com>
Committed: Fri Sep 19 15:33:42 2014 -0700

----------------------------------------------------------------------
 python/pyspark/sql.py   | 13 ++++++++++---
 python/pyspark/tests.py | 11 ++++++++++-
 2 files changed, 20 insertions(+), 4 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/a95ad99e/python/pyspark/sql.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py
index 42a9920..653195e 100644
--- a/python/pyspark/sql.py
+++ b/python/pyspark/sql.py
@@ -440,6 +440,7 @@ _type_mappings = {
     float: DoubleType,
     str: StringType,
     unicode: StringType,
+    bytearray: BinaryType,
     decimal.Decimal: DecimalType,
     datetime.datetime: TimestampType,
     datetime.date: TimestampType,
@@ -690,11 +691,12 @@ _acceptable_types = {
     ByteType: (int, long),
     ShortType: (int, long),
     IntegerType: (int, long),
-    LongType: (long,),
+    LongType: (int, long),
     FloatType: (float,),
     DoubleType: (float,),
     DecimalType: (decimal.Decimal,),
     StringType: (str, unicode),
+    BinaryType: (bytearray,),
     TimestampType: (datetime.datetime,),
     ArrayType: (list, tuple, array),
     MapType: (dict,),
@@ -728,9 +730,9 @@ def _verify_type(obj, dataType):
         return
 
     _type = type(dataType)
-    if _type not in _acceptable_types:
-        return
+    assert _type in _acceptable_types, "unknown datatype: %s" % dataType
 
+    # subclasses of these types cannot be deserialized by the JVM
     if type(obj) not in _acceptable_types[_type]:
     raise TypeError("%s cannot accept object of type %s"
                         % (dataType, type(obj)))
@@ -1121,6 +1123,11 @@ class SQLContext(object):
 
         # take the first few rows to verify schema
         rows = rdd.take(10)
+        # Row() cannot be deserialized by Pyrolite
+        if rows and isinstance(rows[0], tuple) and rows[0].__class__.__name__ == 'Row':
+            rdd = rdd.map(tuple)
+            rows = rdd.take(10)
+
         for row in rows:
             _verify_type(row, schema)
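
The heart of the change above: Row is a tuple subclass that Pyrolite (the
library the JVM side uses to unpickle Python objects) cannot deserialize,
so applySchema() now samples a few elements and, if they are Rows,
downgrades the whole RDD to plain tuples first. A standalone sketch of
that logic, with rdd standing in for any RDD of Row:

    rows = rdd.take(10)  # sample a few elements to sniff the type
    if rows and isinstance(rows[0], tuple) and rows[0].__class__.__name__ == 'Row':
        # map(tuple) keeps the field values but drops the Row subclass,
        # which the JVM side cannot deserialize
        rdd = rdd.map(tuple)
        rows = rdd.take(10)  # re-sample for the schema verification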
 

http://git-wip-us.apache.org/repos/asf/spark/blob/a95ad99e/python/pyspark/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index 7301966..a94eb0f 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -45,7 +45,7 @@ from pyspark.files import SparkFiles
 from pyspark.serializers import read_int, BatchedSerializer, MarshalSerializer, PickleSerializer, \
     CloudPickleSerializer
 from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, ExternalSorter
-from pyspark.sql import SQLContext, IntegerType
+from pyspark.sql import SQLContext, IntegerType, Row
 from pyspark import shuffle
 
 _have_scipy = False
@@ -659,6 +659,15 @@ class TestSQL(PySparkTestCase):
         self.assertEquals(result.getNumPartitions(), 5)
         self.assertEquals(result.count(), 3)
 
+    def test_apply_schema_to_row(self):
+        srdd = self.sqlCtx.jsonRDD(self.sc.parallelize(["""{"a":2}"""]))
+        srdd2 = self.sqlCtx.applySchema(srdd.map(lambda x: x), srdd.schema())
+        self.assertEqual(srdd.collect(), srdd2.collect())
+
+        rdd = self.sc.parallelize(range(10)).map(lambda x: Row(a=x))
+        srdd3 = self.sqlCtx.applySchema(rdd, srdd.schema())
+        self.assertEqual(10, srdd3.count())
+
 
 class TestIO(PySparkTestCase):
 
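
A closing note on the widened type tables in sql.py: plain Python ints now
satisfy LongType, and bytearray satisfies the new BinaryType mapping. A
hedged sketch (field names are made up; assumes the same sc and sqlCtx as
in the tests):

    from pyspark.sql import StructType, StructField, LongType, BinaryType

    schema = StructType([StructField("id", LongType(), False),
                         StructField("data", BinaryType(), False)])
    # An int where the schema says LongType and a bytearray for BinaryType
    # both pass _verify_type() after this patch
    rdd = sc.parallelize([(1, bytearray(b"\x00\x01"))])
    srdd = sqlCtx.applySchema(rdd, schema)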

