Posted to commits@spark.apache.org by ma...@apache.org on 2014/07/02 19:03:49 UTC

git commit: SPARK-2186: Spark SQL DSL support for simple aggregations such as SUM and AVG

Repository: spark
Updated Branches:
  refs/heads/master 6596392da -> 5c6ec94da


SPARK-2186: Spark SQL DSL support for simple aggregations such as SUM and AVG

**Description** This patch enables using the `.select()` function in SchemaRDD with aggregate functions such as `Sum`, `Count`, and others.
**Testing** Unit tests added.
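
A minimal usage sketch of what the DSL now allows. The `Person` case class, the
`people` RDD, and the column names are illustrative assumptions, not part of this
patch; the `import sqlContext._` line is assumed to bring the catalyst DSL
conversions and the implicit RDD-to-SchemaRDD conversion into scope:

    import org.apache.spark.SparkContext
    import org.apache.spark.sql.SQLContext

    case class Person(name: String, age: Int)

    val sc = new SparkContext("local", "dsl-example")
    val sqlContext = new SQLContext(sc)
    import sqlContext._  // assumed: catalyst DSL + implicit SchemaRDD conversion

    val people = sc.parallelize(Seq(Person("alice", 30), Person("bob", 40)))

    // Aggregate expressions can now be passed straight to select():
    people.select(avg('age), count(1), max('age)).collect()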

Author: Ximo Guanter Gonzalbez <xi...@tid.es>

Closes #1211 from edrevo/add-expression-support-in-select and squashes the following commits:

fe4a1e1 [Ximo Guanter Gonzalbez] Extend SQL DSL to functions
e1d344a [Ximo Guanter Gonzalbez] SPARK-2186: Spark SQL DSL support for simple aggregations such as SUM and AVG


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/5c6ec94d
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/5c6ec94d
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/5c6ec94d

Branch: refs/heads/master
Commit: 5c6ec94da1bacd8e65a43acb92b6721493484e7b
Parents: 6596392
Author: Ximo Guanter Gonzalbez <xi...@tid.es>
Authored: Wed Jul 2 10:03:44 2014 -0700
Committer: Michael Armbrust <mi...@databricks.com>
Committed: Wed Jul 2 10:03:44 2014 -0700

----------------------------------------------------------------------
 .../apache/spark/sql/catalyst/dsl/package.scala | 11 +++++++
 .../scala/org/apache/spark/sql/SchemaRDD.scala  |  9 ++++--
 .../org/apache/spark/sql/DslQuerySuite.scala    | 32 ++++++++++++++++----
 3 files changed, 44 insertions(+), 8 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/5c6ec94d/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
index 26ad483..1b503b9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
@@ -108,6 +108,17 @@ package object dsl {
 
     implicit def symbolToUnresolvedAttribute(s: Symbol) = analysis.UnresolvedAttribute(s.name)
 
+    def sum(e: Expression) = Sum(e)
+    def sumDistinct(e: Expression) = SumDistinct(e)
+    def count(e: Expression) = Count(e)
+    def countDistinct(e: Expression*) = CountDistinct(e)
+    def avg(e: Expression) = Average(e)
+    def first(e: Expression) = First(e)
+    def min(e: Expression) = Min(e)
+    def max(e: Expression) = Max(e)
+    def upper(e: Expression) = Upper(e)
+    def lower(e: Expression) = Lower(e)
+
     implicit class DslSymbol(sym: Symbol) extends ImplicitAttribute { def s = sym.name }
     // TODO more implicit class for literal?
     implicit class DslString(val s: String) extends ImplicitOperators {

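The helpers added above are thin wrappers around the existing catalyst aggregate
expressions. Combined with the `symbolToUnresolvedAttribute` implicit already in
scope, a DSL call expands as sketched here (illustrative equivalence, assuming the
usual catalyst expression imports):

    sum('value)                          // DSL form added by this patch
    Sum(UnresolvedAttribute("value"))    // the expression it constructs
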
http://git-wip-us.apache.org/repos/asf/spark/blob/5c6ec94d/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
index 7c0efb4..8f9f54f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
@@ -133,8 +133,13 @@ class SchemaRDD(
    *
    * @group Query
    */
-  def select(exprs: NamedExpression*): SchemaRDD =
-    new SchemaRDD(sqlContext, Project(exprs, logicalPlan))
+  def select(exprs: Expression*): SchemaRDD = {
+    val aliases = exprs.zipWithIndex.map {
+      case (ne: NamedExpression, _) => ne
+      case (e, i) => Alias(e, s"c$i")()
+    }
+    new SchemaRDD(sqlContext, Project(aliases, logicalPlan))
+  }
 
   /**
    * Filters the output, only returning those rows where `condition` evaluates to true.

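With the change above, `select()` accepts arbitrary `Expression`s rather than only
`NamedExpression`s: expressions that already carry a name pass through unchanged,
while anything else is wrapped in an `Alias` named from its position, following the
`c$i` pattern in the code. A sketch of the resulting column names (the RDD name is
illustrative):

    testData2.select('a + 'b)       // single unnamed expression => column c0
    testData2.select('a, 'a + 'b)   // 'a keeps its name; 'a + 'b becomes c1
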
http://git-wip-us.apache.org/repos/asf/spark/blob/5c6ec94d/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala
index e4a64a7..04ac008 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala
@@ -60,6 +60,26 @@ class DslQuerySuite extends QueryTest {
       Seq(Seq("1")))
   }
 
+  test("select with functions") {
+    checkAnswer(
+      testData.select(sum('value), avg('value), count(1)),
+      Seq(Seq(5050.0, 50.5, 100)))
+
+    checkAnswer(
+      testData2.select('a + 'b, 'a < 'b),
+      Seq(
+        Seq(2, false),
+        Seq(3, true),
+        Seq(3, false),
+        Seq(4, false),
+        Seq(4, false),
+        Seq(5, false)))
+
+    checkAnswer(
+      testData2.select(sumDistinct('a)),
+      Seq(Seq(6)))
+  }
+
   test("sorting") {
     checkAnswer(
       testData2.orderBy('a.asc, 'b.asc),
@@ -110,17 +130,17 @@ class DslQuerySuite extends QueryTest {
 
   test("average") {
     checkAnswer(
-      testData2.groupBy()(Average('a)),
+      testData2.groupBy()(avg('a)),
       2.0)
   }
 
   test("null average") {
     checkAnswer(
-      testData3.groupBy()(Average('b)),
+      testData3.groupBy()(avg('b)),
       2.0)
 
     checkAnswer(
-      testData3.groupBy()(Average('b), CountDistinct('b :: Nil)),
+      testData3.groupBy()(avg('b), countDistinct('b)),
       (2.0, 1) :: Nil)
   }
 
@@ -130,17 +150,17 @@ class DslQuerySuite extends QueryTest {
 
   test("null count") {
     checkAnswer(
-      testData3.groupBy('a)('a, Count('b)),
+      testData3.groupBy('a)('a, count('b)),
       Seq((1,0), (2, 1))
     )
 
     checkAnswer(
-      testData3.groupBy('a)('a, Count('a + 'b)),
+      testData3.groupBy('a)('a, count('a + 'b)),
       Seq((1,0), (2, 1))
     )
 
     checkAnswer(
-      testData3.groupBy()(Count('a), Count('b), Count(1), CountDistinct('a :: Nil), CountDistinct('b :: Nil)),
+      testData3.groupBy()(count('a), count('b), count(1), countDistinct('a), countDistinct('b)),
       (2, 1, 2, 2, 1) :: Nil
     )
   }