Posted to commits@spark.apache.org by ma...@apache.org on 2015/02/21 00:35:17 UTC

spark git commit: [SPARK-5898] [SPARK-5896] [SQL] [PySpark] create DataFrame from pandas and tuple/list

Repository: spark
Updated Branches:
  refs/heads/master 4a17eedb1 -> 5b0a42cb1


[SPARK-5898] [SPARK-5896] [SQL]  [PySpark] create DataFrame from pandas and tuple/list

Fix createDataFrame() from a pandas DataFrame (not tested by Jenkins; depends on SPARK-5693).

It also supports creating a DataFrame from plain tuples/lists without column names; `_1`, `_2`, ... will be used as the column names.
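
A minimal sketch of both behaviors, assuming a SparkContext `sc` and a
SQLContext `sqlCtx` are set up as in the doctests below:

    import pandas

    rows = [('Alice', 1), ('Bob', 2)]

    # Plain tuples/lists without a schema: columns default to _1, _2, ...
    sqlCtx.createDataFrame(rows).collect()
    # [Row(_1=u'Alice', _2=1), Row(_1=u'Bob', _2=2)]

    # pandas DataFrame: column names come from the pandas columns
    pdf = pandas.DataFrame(rows, columns=['name', 'age'])
    sqlCtx.createDataFrame(pdf).collect()
    # [Row(name=u'Alice', age=1), Row(name=u'Bob', age=2)]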

Author: Davies Liu <da...@databricks.com>

Closes #4679 from davies/pandas and squashes the following commits:

c0cbe0b [Davies Liu] fix tests
8466d1d [Davies Liu] fix create DataFrame from pandas


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/5b0a42cb
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/5b0a42cb
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/5b0a42cb

Branch: refs/heads/master
Commit: 5b0a42cb17b840c82d3f8a5ad061d99e261ceadf
Parents: 4a17eed
Author: Davies Liu <da...@databricks.com>
Authored: Fri Feb 20 15:35:05 2015 -0800
Committer: Michael Armbrust <mi...@databricks.com>
Committed: Fri Feb 20 15:35:05 2015 -0800

----------------------------------------------------------------------
 python/pyspark/sql/context.py | 12 ++++++++++--
 python/pyspark/sql/tests.py   |  2 +-
 python/pyspark/sql/types.py   | 26 +++++++++-----------------
 3 files changed, 20 insertions(+), 20 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/5b0a42cb/python/pyspark/sql/context.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py
index 3f168f7..313f15e 100644
--- a/python/pyspark/sql/context.py
+++ b/python/pyspark/sql/context.py
@@ -351,6 +351,8 @@ class SQLContext(object):
         :return: a DataFrame
 
         >>> l = [('Alice', 1)]
+        >>> sqlCtx.createDataFrame(l).collect()
+        [Row(_1=u'Alice', _2=1)]
         >>> sqlCtx.createDataFrame(l, ['name', 'age']).collect()
         [Row(name=u'Alice', age=1)]
 
@@ -359,6 +361,8 @@ class SQLContext(object):
         [Row(age=1, name=u'Alice')]
 
         >>> rdd = sc.parallelize(l)
+        >>> sqlCtx.createDataFrame(rdd).collect()
+        [Row(_1=u'Alice', _2=1)]
         >>> df = sqlCtx.createDataFrame(rdd, ['name', 'age'])
         >>> df.collect()
         [Row(name=u'Alice', age=1)]
@@ -377,14 +381,17 @@ class SQLContext(object):
         >>> df3 = sqlCtx.createDataFrame(rdd, schema)
         >>> df3.collect()
         [Row(name=u'Alice', age=1)]
+
+        >>> sqlCtx.createDataFrame(df.toPandas()).collect()  # doctest: +SKIP
+        [Row(name=u'Alice', age=1)]
         """
         if isinstance(data, DataFrame):
             raise TypeError("data is already a DataFrame")
 
         if has_pandas and isinstance(data, pandas.DataFrame):
-            data = self._sc.parallelize(data.to_records(index=False))
             if schema is None:
                 schema = list(data.columns)
+            data = [r.tolist() for r in data.to_records(index=False)]
 
         if not isinstance(data, RDD):
             try:
@@ -399,7 +406,8 @@ class SQLContext(object):
         if isinstance(schema, (list, tuple)):
             first = data.first()
             if not isinstance(first, (list, tuple)):
-                raise ValueError("each row in `rdd` should be list or tuple")
+                raise ValueError("each row in `rdd` should be list or tuple, "
+                                 "but got %r" % type(first))
             row_cls = Row(*schema)
             schema = self._inferSchema(data.map(lambda r: row_cls(*r)), samplingRatio)
 

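Note on the pandas branch above: the raw output of to_records() carries
numpy scalar types, so instead of parallelizing it directly, each record is
first converted to plain Python values and then handed to the generic
list-of-tuples path below. A rough sketch of what that conversion yields
(illustrative only, assuming pandas is installed):

    import pandas

    pdf = pandas.DataFrame([('Alice', 1)], columns=['name', 'age'])
    data = [r.tolist() for r in pdf.to_records(index=False)]
    # data is roughly [('Alice', 1)]: native Python values that the schema
    # inference and serialization below can handle directly
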
http://git-wip-us.apache.org/repos/asf/spark/blob/5b0a42cb/python/pyspark/sql/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 8e1bb36..39071e7 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -186,7 +186,7 @@ class SQLTests(ReusedPySparkTestCase):
         self.assertEqual("2", row.d)
 
     def test_infer_schema(self):
-        d = [Row(l=[], d={}),
+        d = [Row(l=[], d={}, s=None),
              Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")}, s="")]
         rdd = self.sc.parallelize(d)
         df = self.sqlCtx.createDataFrame(rdd)

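The first fixture row previously omitted the field s altogether; it now
carries s=None, so both rows share the same fields and the test exercises
the new handling of None during inference (see the types.py change below).
A hedged sketch of the behavior being locked in, assuming a live sqlCtx:

    rows = [Row(l=[], d={}, s=None),
            Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")}, s="")]
    df = sqlCtx.createDataFrame(sc.parallelize(rows))
    # Before this commit, inferring the type of s=None would have raised
    # ValueError("Can not infer type for None").
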
http://git-wip-us.apache.org/repos/asf/spark/blob/5b0a42cb/python/pyspark/sql/types.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index 9409c6f..b6e41cf 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -604,7 +604,7 @@ def _infer_type(obj):
     ExamplePointUDT
     """
     if obj is None:
-        raise ValueError("Can not infer type for None")
+        return NullType()
 
     if hasattr(obj, '__UDT__'):
         return obj.__UDT__
@@ -637,15 +637,14 @@ def _infer_schema(row):
     if isinstance(row, dict):
         items = sorted(row.items())
 
-    elif isinstance(row, tuple):
+    elif isinstance(row, (tuple, list)):
         if hasattr(row, "_fields"):  # namedtuple
             items = zip(row._fields, tuple(row))
         elif hasattr(row, "__FIELDS__"):  # Row
             items = zip(row.__FIELDS__, tuple(row))
-        elif all(isinstance(x, tuple) and len(x) == 2 for x in row):
-            items = row
         else:
-            raise ValueError("Can't infer schema from tuple")
+            names = ['_%d' % i for i in range(1, len(row) + 1)]
+            items = zip(names, row)
 
     elif hasattr(row, "__dict__"):  # object
         items = sorted(row.__dict__.items())
@@ -812,17 +811,10 @@ def _create_converter(dataType):
         if obj is None:
             return
 
-        if isinstance(obj, tuple):
-            if hasattr(obj, "_fields"):
-                d = dict(zip(obj._fields, obj))
-            elif hasattr(obj, "__FIELDS__"):
-                d = dict(zip(obj.__FIELDS__, obj))
-            elif all(isinstance(x, tuple) and len(x) == 2 for x in obj):
-                d = dict(obj)
-            else:
-                raise ValueError("unexpected tuple: %s" % str(obj))
+        if isinstance(obj, (tuple, list)):
+            return tuple(conv(v) for v, conv in zip(obj, converters))
 
-        elif isinstance(obj, dict):
+        if isinstance(obj, dict):
             d = obj
         elif hasattr(obj, "__dict__"):  # object
             d = obj.__dict__
@@ -1022,7 +1014,7 @@ def _verify_type(obj, dataType):
         return
 
     _type = type(dataType)
-    assert _type in _acceptable_types, "unkown datatype: %s" % dataType
+    assert _type in _acceptable_types, "unknown datatype: %s" % dataType
 
     # subclass of them can not be deserialized in JVM
     if type(obj) not in _acceptable_types[_type]:
@@ -1040,7 +1032,7 @@ def _verify_type(obj, dataType):
 
     elif isinstance(dataType, StructType):
         if len(obj) != len(dataType.fields):
-            raise ValueError("Length of object (%d) does not match with"
+            raise ValueError("Length of object (%d) does not match with "
                              "length of fields (%d)" % (len(obj), len(dataType.fields)))
         for v, f in zip(obj, dataType.fields):
             _verify_type(v, f.dataType)

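Taken together, the types.py changes let the inference helpers accept plain
tuples/lists and None. A quick sketch against the private helpers shown in
the diff above (internal API, for illustration only):

    from pyspark.sql.types import _infer_type, _infer_schema

    _infer_type(None)                  # now NullType() instead of ValueError
    schema = _infer_schema(('Alice', 1))
    [f.name for f in schema.fields]    # ['_1', '_2'], the default names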

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org