Posted to commits@spark.apache.org by ma...@apache.org on 2015/04/17 02:34:00 UTC

spark git commit: [SPARK-6911] [SQL] improve accessor for nested types

Repository: spark
Updated Branches:
  refs/heads/master 5fe434335 -> 6183b5e2c


[SPARK-6911] [SQL] improve accessor for nested types

Support accessing columns by index in Python:
```
>>> df[df[0] > 3].collect()
[Row(age=5, name=u'Bob')]
```
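
A minimal, self-contained sketch of the DataFrame assumed above (requires a running SparkContext `sc` and an SQLContext so that `RDD.toDF` is available; Python 2 reprs shown):
```
>>> df = sc.parallelize([(2, u'Alice'), (5, u'Bob')]).toDF(['age', 'name'])
>>> df[df[0] > 3].collect()   # df[0] is the first column, i.e. df.age
[Row(age=5, name=u'Bob')]
```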

Access items in ArrayType or MapType columns:
```
>>> df.select(df.l.getItem(0), df.d.getItem("key")).show()
>>> df.select(df.l[0], df.d["key"]).show()
```
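
A runnable sketch of the list/dict accessors, mirroring the new doctests added to dataframe.py below:
```
>>> df = sc.parallelize([([1, 2], {"key": "value"})]).toDF(["l", "d"])
>>> df.select(df.l.getItem(0), df.d.getItem("key")).show()
l[0] d[key]
1    value
```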

Access fields in StructType columns:
```
>>> df.select(df.r.getField("b")).show()
>>> df.select(df.r.a).show()
```
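
And for struct fields, mirroring the getField doctest below (a nested Row builds the struct):
```
>>> from pyspark.sql import Row
>>> df = sc.parallelize([Row(r=Row(a=1, b="b"))]).toDF()
>>> df.select(df.r.getField("b")).show()
r.b
b
>>> df.select(df.r.a).show()
r.a
1
```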

Author: Davies Liu <da...@databricks.com>

Closes #5513 from davies/access and squashes the following commits:

e04d5a0 [Davies Liu] Update run-tests-jenkins
7ada9eb [Davies Liu] update timeout
d125ac4 [Davies Liu] check column name, improve scala tests
6b62540 [Davies Liu] fix test
db15b42 [Davies Liu] Merge branch 'master' of github.com:apache/spark into access
6c32e79 [Davies Liu] add scala tests
11f1df3 [Davies Liu] improve accessor for nested types


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/6183b5e2
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/6183b5e2
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/6183b5e2

Branch: refs/heads/master
Commit: 6183b5e2caedd074073d0f6cb6609a634e2f5194
Parents: 5fe4343
Author: Davies Liu <da...@databricks.com>
Authored: Thu Apr 16 17:33:57 2015 -0700
Committer: Michael Armbrust <mi...@databricks.com>
Committed: Thu Apr 16 17:33:57 2015 -0700

----------------------------------------------------------------------
 python/pyspark/sql/dataframe.py                 | 49 ++++++++++++++++++--
 python/pyspark/sql/tests.py                     | 18 +++++++
 .../scala/org/apache/spark/sql/Column.scala     |  7 +--
 .../org/apache/spark/sql/DataFrameSuite.scala   |  6 +++
 .../scala/org/apache/spark/sql/TestData.scala   |  9 ++--
 5 files changed, 76 insertions(+), 13 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/6183b5e2/python/pyspark/sql/dataframe.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index d76504f..b9a3e6c 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -563,16 +563,23 @@ class DataFrame(object):
         [Row(name=u'Alice', age=2), Row(name=u'Bob', age=5)]
         >>> df[ df.age > 3 ].collect()
         [Row(age=5, name=u'Bob')]
+        >>> df[df[0] > 3].collect()
+        [Row(age=5, name=u'Bob')]
         """
         if isinstance(item, basestring):
+            if item not in self.columns:
+                raise IndexError("no such column: %s" % item)
             jc = self._jdf.apply(item)
             return Column(jc)
         elif isinstance(item, Column):
             return self.filter(item)
-        elif isinstance(item, list):
+        elif isinstance(item, (list, tuple)):
             return self.select(*item)
+        elif isinstance(item, int):
+            jc = self._jdf.apply(self.columns[item])
+            return Column(jc)
         else:
-            raise IndexError("unexpected index: %s" % item)
+            raise TypeError("unexpected type: %s" % type(item))
 
     def __getattr__(self, name):
         """Returns the :class:`Column` denoted by ``name``.
@@ -580,8 +587,8 @@ class DataFrame(object):
         >>> df.select(df.age).collect()
         [Row(age=2), Row(age=5)]
         """
-        if name.startswith("__"):
-            raise AttributeError(name)
+        if name not in self.columns:
+            raise AttributeError("No such column: %s" % name)
         jc = self._jdf.apply(name)
         return Column(jc)
 
@@ -1093,7 +1100,39 @@ class Column(object):
     # container operators
     __contains__ = _bin_op("contains")
     __getitem__ = _bin_op("getItem")
-    getField = _bin_op("getField", "An expression that gets a field by name in a StructField.")
+
+    def getItem(self, key):
+        """An expression that gets an item at position `ordinal` out of a list,
+         or gets an item by key out of a dict.
+
+        >>> df = sc.parallelize([([1, 2], {"key": "value"})]).toDF(["l", "d"])
+        >>> df.select(df.l.getItem(0), df.d.getItem("key")).show()
+        l[0] d[key]
+        1    value
+        >>> df.select(df.l[0], df.d["key"]).show()
+        l[0] d[key]
+        1    value
+        """
+        return self[key]
+
+    def getField(self, name):
+        """An expression that gets a field by name in a StructField.
+
+        >>> from pyspark.sql import Row
+        >>> df = sc.parallelize([Row(r=Row(a=1, b="b"))]).toDF()
+        >>> df.select(df.r.getField("b")).show()
+        r.b
+        b
+        >>> df.select(df.r.a).show()
+        r.a
+        1
+        """
+        return Column(self._jc.getField(name))
+
+    def __getattr__(self, item):
+        if item.startswith("__"):
+            raise AttributeError(item)
+        return self.getField(item)
 
     # string methods
     rlike = _bin_op("rlike")
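
(For reference, a sketch of the indexing behavior the __getitem__ change above introduces, assuming the same two-row df with columns `age` and `name` used in the module doctests; tracebacks abbreviated:)
```
>>> df[('name', 'age')].collect()   # tuples now select like lists
[Row(name=u'Alice', age=2), Row(name=u'Bob', age=5)]
>>> df['bad_key']                   # unknown column names fail fast
Traceback (most recent call last):
    ...
IndexError: no such column: bad_key
>>> df[{}]                          # unsupported index types raise TypeError
Traceback (most recent call last):
    ...
TypeError: unexpected type: <type 'dict'>
```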

http://git-wip-us.apache.org/repos/asf/spark/blob/6183b5e2/python/pyspark/sql/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 7c09a0c..6691e8c 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -426,6 +426,24 @@ class SQLTests(ReusedPySparkTestCase):
         pydoc.render_doc(df.foo)
         pydoc.render_doc(df.take(1))
 
+    def test_access_column(self):
+        df = self.df
+        self.assertTrue(isinstance(df.key, Column))
+        self.assertTrue(isinstance(df['key'], Column))
+        self.assertTrue(isinstance(df[0], Column))
+        self.assertRaises(IndexError, lambda: df[2])
+        self.assertRaises(IndexError, lambda: df["bad_key"])
+        self.assertRaises(TypeError, lambda: df[{}])
+
+    def test_access_nested_types(self):
+        df = self.sc.parallelize([Row(l=[1], r=Row(a=1, b="b"), d={"k": "v"})]).toDF()
+        self.assertEqual(1, df.select(df.l[0]).first()[0])
+        self.assertEqual(1, df.select(df.l.getItem(0)).first()[0])
+        self.assertEqual(1, df.select(df.r.a).first()[0])
+        self.assertEqual("b", df.select(df.r.getField("b")).first()[0])
+        self.assertEqual("v", df.select(df.d["k"]).first()[0])
+        self.assertEqual("v", df.select(df.d.getItem("k")).first()[0])
+
     def test_infer_long_type(self):
         longrow = [Row(f1='a', f2=100000000000000)]
         df = self.sc.parallelize(longrow).toDF()

http://git-wip-us.apache.org/repos/asf/spark/blob/6183b5e2/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index 3cd7adf..edb229c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -515,14 +515,15 @@ class Column(protected[sql] val expr: Expression) extends Logging {
   def rlike(literal: String): Column = RLike(expr, lit(literal).expr)
 
   /**
-   * An expression that gets an item at position `ordinal` out of an array.
+   * An expression that gets an item at position `ordinal` out of an array,
+   * or gets a value by key `key` in a [[MapType]].
    *
    * @group expr_ops
    */
-  def getItem(ordinal: Int): Column = GetItem(expr, Literal(ordinal))
+  def getItem(key: Any): Column = GetItem(expr, Literal(key))
 
   /**
-   * An expression that gets a field by name in a [[StructField]].
+   * An expression that gets a field by name in a [[StructType]].
    *
    * @group expr_ops
    */

http://git-wip-us.apache.org/repos/asf/spark/blob/6183b5e2/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index b26e22f..34b2cb0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -86,6 +86,12 @@ class DataFrameSuite extends QueryTest {
     TestSQLContext.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, oldSetting.toString)
   }
 
+  test("access complex data") {
+    assert(complexData.filter(complexData("a").getItem(0) === 2).count() == 1)
+    assert(complexData.filter(complexData("m").getItem("1") === 1).count() == 1)
+    assert(complexData.filter(complexData("s").getField("key") === 1).count() == 1)
+  }
+
   test("table scan") {
     checkAnswer(
       testData,

http://git-wip-us.apache.org/repos/asf/spark/blob/6183b5e2/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
index 637f59b..225b51b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
@@ -20,9 +20,8 @@ package org.apache.spark.sql
 import java.sql.Timestamp
 
 import org.apache.spark.sql.catalyst.plans.logical
-import org.apache.spark.sql.functions._
-import org.apache.spark.sql.test._
 import org.apache.spark.sql.test.TestSQLContext.implicits._
+import org.apache.spark.sql.test._
 
 
 case class TestData(key: Int, value: String)
@@ -199,11 +198,11 @@ object TestData {
     Salary(1, 1000.0) :: Nil).toDF()
   salary.registerTempTable("salary")
 
-  case class ComplexData(m: Map[Int, String], s: TestData, a: Seq[Int], b: Boolean)
+  case class ComplexData(m: Map[String, Int], s: TestData, a: Seq[Int], b: Boolean)
   val complexData =
     TestSQLContext.sparkContext.parallelize(
-      ComplexData(Map(1 -> "1"), TestData(1, "1"), Seq(1), true)
-        :: ComplexData(Map(2 -> "2"), TestData(2, "2"), Seq(2), false)
+      ComplexData(Map("1" -> 1), TestData(1, "1"), Seq(1), true)
+        :: ComplexData(Map("2" -> 2), TestData(2, "2"), Seq(2), false)
         :: Nil).toDF()
   complexData.registerTempTable("complexData")
 }

