You are viewing a plain-text version of this content. (The hyperlink to its canonical version did not survive conversion to plain text.)
Posted to commits@spark.apache.org by ma...@apache.org on 2014/12/19 05:20:14 UTC

spark git commit: [SPARK-2554][SQL] Supporting SumDistinct partial aggregation

Repository: spark
Updated Branches:
  refs/heads/master e7de7e5f4 -> 7687415c2


[SPARK-2554][SQL] Supporting SumDistinct partial aggregation

Adds support for partial aggregation of SumDistinct.

Author: ravipesala <ra...@huawei.com>

Closes #3348 from ravipesala/SPARK-2554 and squashes the following commits:

fd28e4d [ravipesala] Fixed review comments
e60e67f [ravipesala] Fixed test cases and made it as nullable
32fe234 [ravipesala] Supporting SumDistinct partial aggregation Conflicts: 	sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/7687415c
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/7687415c
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/7687415c

Branch: refs/heads/master
Commit: 7687415c2578b5bdc79c9646c246e52da9a4dd4a
Parents: e7de7e5
Author: ravipesala <ra...@huawei.com>
Authored: Thu Dec 18 20:19:10 2014 -0800
Committer: Michael Armbrust <mi...@databricks.com>
Committed: Thu Dec 18 20:19:10 2014 -0800

----------------------------------------------------------------------
 .../sql/catalyst/expressions/aggregates.scala   | 53 ++++++++++++++++++--
 .../sql/hive/execution/SQLQuerySuite.scala      | 13 +++--
 2 files changed, 58 insertions(+), 8 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/7687415c/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
index 0cd9086..5ea9868 100755
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
@@ -361,10 +361,10 @@ case class Sum(child: Expression) extends PartialAggregate with trees.UnaryNode[
 }
 
 case class SumDistinct(child: Expression)
-  extends AggregateExpression with trees.UnaryNode[Expression] {
+  extends PartialAggregate with trees.UnaryNode[Expression] {
 
+  def this() = this(null)
   override def nullable = true
-
   override def dataType = child.dataType match {
     case DecimalType.Fixed(precision, scale) =>
       DecimalType(precision + 10, scale)  // Add 10 digits left of decimal point, like Hive
@@ -373,10 +373,55 @@ case class SumDistinct(child: Expression)
     case _ =>
       child.dataType
   }
+  override def toString = s"SUM(DISTINCT ${child})"
+  override def newInstance() = new SumDistinctFunction(child, this)
+
+  override def asPartial = {
+    val partialSet = Alias(CollectHashSet(child :: Nil), "partialSets")()
+    SplitEvaluation(
+      CombineSetsAndSum(partialSet.toAttribute, this),
+      partialSet :: Nil)
+  }
+}
 
-  override def toString = s"SUM(DISTINCT $child)"
+case class CombineSetsAndSum(inputSet: Expression, base: Expression) extends AggregateExpression {
+  def this() = this(null, null)
 
-  override def newInstance() = new SumDistinctFunction(child, this)
+  override def children = inputSet :: Nil
+  override def nullable = true
+  override def dataType = base.dataType
+  override def toString = s"CombineAndSum($inputSet)"
+  override def newInstance() = new CombineSetsAndSumFunction(inputSet, this)
+}
+
+case class CombineSetsAndSumFunction(
+    @transient inputSet: Expression,
+    @transient base: AggregateExpression)
+  extends AggregateFunction {
+
+  def this() = this(null, null) // Required for serialization.
+
+  val seen = new OpenHashSet[Any]()
+
+  override def update(input: Row): Unit = {
+    val inputSetEval = inputSet.eval(input).asInstanceOf[OpenHashSet[Any]]
+    val inputIterator = inputSetEval.iterator
+    while (inputIterator.hasNext) {
+      seen.add(inputIterator.next)
+    }
+  }
+
+  override def eval(input: Row): Any = {
+    val casted = seen.asInstanceOf[OpenHashSet[Row]]
+    if (casted.size == 0) {
+      null
+    } else {
+      Cast(Literal(
+        casted.iterator.map(f => f.apply(0)).reduceLeft(
+          base.dataType.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]].plus)),
+        base.dataType).eval(null)
+    }
+  }
 }
 
 case class First(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {

http://git-wip-us.apache.org/repos/asf/spark/blob/7687415c/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
index 96f3430..f57f31a 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
@@ -185,9 +185,14 @@ class SQLQuerySuite extends QueryTest {
       sql("SELECT case when ~1=-2 then 1 else 0 end FROM src"),
       sql("SELECT 1 FROM src").collect().toSeq)
   }
-  
- test("SPARK-4154 Query does not work if it has 'not between' in Spark SQL and HQL") {
-    checkAnswer(sql("SELECT key FROM src WHERE key not between 0 and 10 order by key"), 
-        sql("SELECT key FROM src WHERE key between 11 and 500 order by key").collect().toSeq)
+
+  test("SPARK-4154 Query does not work if it has 'not between' in Spark SQL and HQL") {
+    checkAnswer(sql("SELECT key FROM src WHERE key not between 0 and 10 order by key"),
+      sql("SELECT key FROM src WHERE key between 11 and 500 order by key").collect().toSeq)
+  }
+
+  test("SPARK-2554 SumDistinct partial aggregation") {
+    checkAnswer(sql("SELECT sum( distinct key) FROM src group by key order by key"),
+      sql("SELECT distinct key FROM src order by key").collect().toSeq)
   }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org