You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by ma...@apache.org on 2014/12/17 06:23:33 UTC

spark git commit: [SPARK-4866] support StructType as key in MapType

Repository: spark
Updated Branches:
  refs/heads/master 770d8153a -> ec5c4279e


[SPARK-4866] support StructType as key in MapType

This PR adds support for using StructType (and other hashable types) as keys in MapType.

Author: Davies Liu <da...@databricks.com>

Closes #3714 from davies/fix_struct_in_map and squashes the following commits:

68585d7 [Davies Liu] fix primitive types in MapType
9601534 [Davies Liu] support StructType as key in MapType


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/ec5c4279
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/ec5c4279
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/ec5c4279

Branch: refs/heads/master
Commit: ec5c4279edabd5ea2b187aff6662ac07ed825b08
Parents: 770d815
Author: Davies Liu <da...@databricks.com>
Authored: Tue Dec 16 21:23:28 2014 -0800
Committer: Michael Armbrust <mi...@databricks.com>
Committed: Tue Dec 16 21:23:28 2014 -0800

----------------------------------------------------------------------
 python/pyspark/sql.py                              | 17 ++++++++++-------
 python/pyspark/tests.py                            |  8 ++++++++
 .../apache/spark/sql/execution/pythonUdfs.scala    |  2 +-
 3 files changed, 19 insertions(+), 8 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/ec5c4279/python/pyspark/sql.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py
index ae28847..1ee0b28 100644
--- a/python/pyspark/sql.py
+++ b/python/pyspark/sql.py
@@ -788,8 +788,9 @@ def _create_converter(dataType):
         return lambda row: map(conv, row)
 
     elif isinstance(dataType, MapType):
-        conv = _create_converter(dataType.valueType)
-        return lambda row: dict((k, conv(v)) for k, v in row.iteritems())
+        kconv = _create_converter(dataType.keyType)
+        vconv = _create_converter(dataType.valueType)
+        return lambda row: dict((kconv(k), vconv(v)) for k, v in row.iteritems())
 
     elif isinstance(dataType, NullType):
         return lambda x: None
@@ -944,7 +945,7 @@ def _infer_schema_type(obj, dataType):
 
     elif isinstance(dataType, MapType):
         k, v = obj.iteritems().next()
-        return MapType(_infer_type(k),
+        return MapType(_infer_schema_type(k, dataType.keyType),
                        _infer_schema_type(v, dataType.valueType))
 
     elif isinstance(dataType, StructType):
@@ -1085,7 +1086,7 @@ def _has_struct_or_date(dt):
     elif isinstance(dt, ArrayType):
         return _has_struct_or_date(dt.elementType)
     elif isinstance(dt, MapType):
-        return _has_struct_or_date(dt.valueType)
+        return _has_struct_or_date(dt.keyType) or _has_struct_or_date(dt.valueType)
     elif isinstance(dt, DateType):
         return True
     elif isinstance(dt, UserDefinedType):
@@ -1148,12 +1149,13 @@ def _create_cls(dataType):
         return List
 
     elif isinstance(dataType, MapType):
-        cls = _create_cls(dataType.valueType)
+        kcls = _create_cls(dataType.keyType)
+        vcls = _create_cls(dataType.valueType)
 
         def Dict(d):
             if d is None:
                 return
-            return dict((k, _create_object(cls, v)) for k, v in d.items())
+            return dict((_create_object(kcls, k), _create_object(vcls, v)) for k, v in d.items())
 
         return Dict
 
@@ -1164,7 +1166,8 @@ def _create_cls(dataType):
         return lambda datum: dataType.deserialize(datum)
 
     elif not isinstance(dataType, StructType):
-        raise Exception("unexpected data type: %s" % dataType)
+        # no wrapper for primitive types
+        return lambda x: x
 
     class Row(tuple):
 

http://git-wip-us.apache.org/repos/asf/spark/blob/ec5c4279/python/pyspark/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index bca52a7..b474fcf 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -923,6 +923,14 @@ class SQLTests(ReusedPySparkTestCase):
         result = self.sqlCtx.sql("SELECT l[0].a from test2 where d['key'].d = '2'")
         self.assertEqual(1, result.first()[0])
 
+    def test_struct_in_map(self):
+        d = [Row(m={Row(i=1): Row(s="")})]
+        rdd = self.sc.parallelize(d)
+        srdd = self.sqlCtx.inferSchema(rdd)
+        k, v = srdd.first().m.items()[0]
+        self.assertEqual(1, k.i)
+        self.assertEqual("", v.s)
+
     def test_convert_row_to_dict(self):
         row = Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")})
         self.assertEqual(1, row.asDict()['l'][0].a)

http://git-wip-us.apache.org/repos/asf/spark/blob/ec5c4279/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
index 2b4a88d..5a41399 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
@@ -132,7 +132,7 @@ object EvaluatePython {
       arr.asInstanceOf[Array[Any]].map(x => toJava(x, array.elementType))
 
     case (obj: Map[_, _], mt: MapType) => obj.map {
-      case (k, v) => (k, toJava(v, mt.valueType)) // key should be primitive type
+      case (k, v) => (toJava(k, mt.keyType), toJava(v, mt.valueType))
     }.asJava
 
     case (ud, udt: UserDefinedType[_]) => toJava(udt.serialize(ud), udt.sqlType)


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org