Posted to commits@spark.apache.org by ma...@apache.org on 2015/01/13 06:44:52 UTC

spark git commit: [SPARK-5138][SQL] Ensure schema can be inferred from a namedtuple

Repository: spark
Updated Branches:
  refs/heads/master 5d9fa5508 -> 1e42e96ec


[SPARK-5138][SQL] Ensure schema can be inferred from a namedtuple

When inferring the schema of an RDD that contains namedtuples, PySpark fails to recognize the records as namedtuples and raises an error.

Example:

```python
from pyspark import SparkContext
from pyspark.sql import SQLContext
from collections import namedtuple
import os

sc = SparkContext()
rdd = sc.textFile(os.path.join(os.getenv('SPARK_HOME'), 'README.md'))
TextLine = namedtuple('TextLine', 'line length')
tuple_rdd = rdd.map(lambda l: TextLine(line=l, length=len(l)))
tuple_rdd.take(5)  # This works

sqlc = SQLContext(sc)

# The following line raises an error
schema_rdd = sqlc.inferSchema(tuple_rdd)
```

The error raised is:
```
  File "/opt/spark-1.2.0-bin-hadoop2.4/python/pyspark/worker.py", line 107, in main
    process()
  File "/opt/spark-1.2.0-bin-hadoop2.4/python/pyspark/worker.py", line 98, in process
    serializer.dump_stream(func(split_index, iterator), outfile)
  File "/opt/spark-1.2.0-bin-hadoop2.4/python/pyspark/serializers.py", line 227, in dump_stream
    vs = list(itertools.islice(iterator, batch))
  File "/opt/spark-1.2.0-bin-hadoop2.4/python/pyspark/rdd.py", line 1107, in takeUpToNumLeft
    yield next(iterator)
  File "/opt/spark-1.2.0-bin-hadoop2.4/python/pyspark/sql.py", line 816, in convert_struct
    raise ValueError("unexpected tuple: %s" % obj)
TypeError: not all arguments converted during string formatting
```
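
The `TypeError` here masks the intended `ValueError`: when the right-hand operand of `%` is a tuple, Python spreads its elements as separate format arguments, so a single `%s` cannot consume a multi-element tuple and the formatting itself fails. The patch sidesteps this by wrapping the tuple in `str()`. A minimal illustration in plain Python (no Spark needed):

```python
obj = ('some line', 9)

# The % operator treats a tuple operand as multiple arguments,
# so one %s placeholder cannot consume a two-element tuple.
try:
    "unexpected tuple: %s" % obj
except TypeError as e:
    print(e)  # not all arguments converted during string formatting

# Converting to str first passes a single argument, as the patch does.
print("unexpected tuple: %s" % str(obj))  # unexpected tuple: ('some line', 9)
```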

Author: Gabe Mulley <ga...@edx.org>

Closes #3978 from mulby/inferschema-namedtuple and squashes the following commits:

98c61cc [Gabe Mulley] Ensure exception message is populated correctly
375d96b [Gabe Mulley] Ensure schema can be inferred from a namedtuple


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/1e42e96e
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/1e42e96e
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/1e42e96e

Branch: refs/heads/master
Commit: 1e42e96ece9e35ceed9ddebef66d589016878b56
Parents: 5d9fa55
Author: Gabe Mulley <ga...@edx.org>
Authored: Mon Jan 12 21:44:51 2015 -0800
Committer: Michael Armbrust <mi...@databricks.com>
Committed: Mon Jan 12 21:44:51 2015 -0800

----------------------------------------------------------------------
 python/pyspark/sql.py | 18 ++++++++++++++----
 1 file changed, 14 insertions(+), 4 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/1e42e96e/python/pyspark/sql.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py
index 0e8b398..014ac17 100644
--- a/python/pyspark/sql.py
+++ b/python/pyspark/sql.py
@@ -807,14 +807,14 @@ def _create_converter(dataType):
             return
 
         if isinstance(obj, tuple):
-            if hasattr(obj, "fields"):
-                d = dict(zip(obj.fields, obj))
-            if hasattr(obj, "__FIELDS__"):
+            if hasattr(obj, "_fields"):
+                d = dict(zip(obj._fields, obj))
+            elif hasattr(obj, "__FIELDS__"):
                 d = dict(zip(obj.__FIELDS__, obj))
             elif all(isinstance(x, tuple) and len(x) == 2 for x in obj):
                 d = dict(obj)
             else:
-                raise ValueError("unexpected tuple: %s" % obj)
+                raise ValueError("unexpected tuple: %s" % str(obj))
 
         elif isinstance(obj, dict):
             d = obj
@@ -1327,6 +1327,16 @@ class SQLContext(object):
         >>> srdd = sqlCtx.inferSchema(nestedRdd2)
         >>> srdd.collect()
         [Row(f1=[[1, 2], [2, 3]], f2=[1, 2]), ..., f2=[2, 3])]
+
+        >>> from collections import namedtuple
+        >>> CustomRow = namedtuple('CustomRow', 'field1 field2')
+        >>> rdd = sc.parallelize(
+        ...     [CustomRow(field1=1, field2="row1"),
+        ...      CustomRow(field1=2, field2="row2"),
+        ...      CustomRow(field1=3, field2="row3")])
+        >>> srdd = sqlCtx.inferSchema(rdd)
+        >>> srdd.collect()[0]
+        Row(field1=1, field2=u'row1')
         """
 
         if isinstance(rdd, SchemaRDD):
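
The fix works because namedtuple instances are ordinary tuples that additionally carry a `_fields` attribute naming their columns; the old code checked a nonexistent `fields` attribute, so namedtuples fell through to the error branch. A minimal sketch of the detection in plain Python:

```python
from collections import namedtuple

TextLine = namedtuple('TextLine', 'line length')
t = TextLine(line='hello', length=5)

isinstance(t, tuple)     # True  -- namedtuples subclass tuple
hasattr(t, '_fields')    # True  -- t._fields == ('line', 'length')
dict(zip(t._fields, t))  # {'line': 'hello', 'length': 5}
```

Note the patch also turns the second `hasattr` check into an `elif`, so a record that matches the namedtuple branch is not re-examined by the later branches.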

