You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by ma...@apache.org on 2014/12/17 06:23:33 UTC
spark git commit: [SPARK-4866] support StructType as key in MapType
Repository: spark
Updated Branches:
refs/heads/master 770d8153a -> ec5c4279e
[SPARK-4866] support StructType as key in MapType
This PR adds support for using StructType (and other hashable types) as keys in MapType.
Author: Davies Liu <da...@databricks.com>
Closes #3714 from davies/fix_struct_in_map and squashes the following commits:
68585d7 [Davies Liu] fix primitive types in MapType
9601534 [Davies Liu] support StructType as key in MapType
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/ec5c4279
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/ec5c4279
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/ec5c4279
Branch: refs/heads/master
Commit: ec5c4279edabd5ea2b187aff6662ac07ed825b08
Parents: 770d815
Author: Davies Liu <da...@databricks.com>
Authored: Tue Dec 16 21:23:28 2014 -0800
Committer: Michael Armbrust <mi...@databricks.com>
Committed: Tue Dec 16 21:23:28 2014 -0800
----------------------------------------------------------------------
python/pyspark/sql.py | 17 ++++++++++-------
python/pyspark/tests.py | 8 ++++++++
.../apache/spark/sql/execution/pythonUdfs.scala | 2 +-
3 files changed, 19 insertions(+), 8 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/ec5c4279/python/pyspark/sql.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py
index ae28847..1ee0b28 100644
--- a/python/pyspark/sql.py
+++ b/python/pyspark/sql.py
@@ -788,8 +788,9 @@ def _create_converter(dataType):
return lambda row: map(conv, row)
elif isinstance(dataType, MapType):
- conv = _create_converter(dataType.valueType)
- return lambda row: dict((k, conv(v)) for k, v in row.iteritems())
+ kconv = _create_converter(dataType.keyType)
+ vconv = _create_converter(dataType.valueType)
+ return lambda row: dict((kconv(k), vconv(v)) for k, v in row.iteritems())
elif isinstance(dataType, NullType):
return lambda x: None
@@ -944,7 +945,7 @@ def _infer_schema_type(obj, dataType):
elif isinstance(dataType, MapType):
k, v = obj.iteritems().next()
- return MapType(_infer_type(k),
+ return MapType(_infer_schema_type(k, dataType.keyType),
_infer_schema_type(v, dataType.valueType))
elif isinstance(dataType, StructType):
@@ -1085,7 +1086,7 @@ def _has_struct_or_date(dt):
elif isinstance(dt, ArrayType):
return _has_struct_or_date(dt.elementType)
elif isinstance(dt, MapType):
- return _has_struct_or_date(dt.valueType)
+ return _has_struct_or_date(dt.keyType) or _has_struct_or_date(dt.valueType)
elif isinstance(dt, DateType):
return True
elif isinstance(dt, UserDefinedType):
@@ -1148,12 +1149,13 @@ def _create_cls(dataType):
return List
elif isinstance(dataType, MapType):
- cls = _create_cls(dataType.valueType)
+ kcls = _create_cls(dataType.keyType)
+ vcls = _create_cls(dataType.valueType)
def Dict(d):
if d is None:
return
- return dict((k, _create_object(cls, v)) for k, v in d.items())
+ return dict((_create_object(kcls, k), _create_object(vcls, v)) for k, v in d.items())
return Dict
@@ -1164,7 +1166,8 @@ def _create_cls(dataType):
return lambda datum: dataType.deserialize(datum)
elif not isinstance(dataType, StructType):
- raise Exception("unexpected data type: %s" % dataType)
+ # no wrapper for primitive types
+ return lambda x: x
class Row(tuple):
http://git-wip-us.apache.org/repos/asf/spark/blob/ec5c4279/python/pyspark/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index bca52a7..b474fcf 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -923,6 +923,14 @@ class SQLTests(ReusedPySparkTestCase):
result = self.sqlCtx.sql("SELECT l[0].a from test2 where d['key'].d = '2'")
self.assertEqual(1, result.first()[0])
+ def test_struct_in_map(self):
+ d = [Row(m={Row(i=1): Row(s="")})]
+ rdd = self.sc.parallelize(d)
+ srdd = self.sqlCtx.inferSchema(rdd)
+ k, v = srdd.first().m.items()[0]
+ self.assertEqual(1, k.i)
+ self.assertEqual("", v.s)
+
def test_convert_row_to_dict(self):
row = Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")})
self.assertEqual(1, row.asDict()['l'][0].a)
http://git-wip-us.apache.org/repos/asf/spark/blob/ec5c4279/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
index 2b4a88d..5a41399 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
@@ -132,7 +132,7 @@ object EvaluatePython {
arr.asInstanceOf[Array[Any]].map(x => toJava(x, array.elementType))
case (obj: Map[_, _], mt: MapType) => obj.map {
- case (k, v) => (k, toJava(v, mt.valueType)) // key should be primitive type
+ case (k, v) => (toJava(k, mt.keyType), toJava(v, mt.valueType))
}.asJava
case (ud, udt: UserDefinedType[_]) => toJava(udt.serialize(ud), udt.sqlType)
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org