You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by rx...@apache.org on 2015/07/24 08:40:07 UTC

spark git commit: [SPARK-9294][SQL] cleanup comments, code style, naming typo for the new aggregation

Repository: spark
Updated Branches:
  refs/heads/master d4d762f27 -> 408e64b28


[SPARK-9294][SQL] cleanup comments, code style, naming typo for the new aggregation

Fix some comments and code style for https://github.com/apache/spark/pull/7458.

Author: Wenchen Fan <cl...@outlook.com>

Closes #7619 from cloud-fan/agg-clean and squashes the following commits:

3925457 [Wenchen Fan] one more...
cc78357 [Wenchen Fan] one more cleanup
26f6a93 [Wenchen Fan] some minor cleanup for the new aggregation


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/408e64b2
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/408e64b2
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/408e64b2

Branch: refs/heads/master
Commit: 408e64b284ef8bd6796d815b5eb603312d090b74
Parents: d4d762f
Author: Wenchen Fan <cl...@outlook.com>
Authored: Thu Jul 23 23:40:01 2015 -0700
Committer: Reynold Xin <rx...@databricks.com>
Committed: Thu Jul 23 23:40:01 2015 -0700

----------------------------------------------------------------------
 .../spark/sql/catalyst/analysis/Analyzer.scala  |  2 +-
 .../expressions/aggregate/interfaces.scala      | 18 ++---
 .../apache/spark/sql/execution/Exchange.scala   |  6 +-
 .../spark/sql/execution/SparkStrategies.scala   |  8 +-
 .../aggregate/sortBasedIterators.scala          | 82 ++++++--------------
 .../spark/sql/execution/aggregate/utils.scala   | 10 +--
 .../org/apache/spark/sql/SQLQuerySuite.scala    |  9 ++-
 7 files changed, 46 insertions(+), 89 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/408e64b2/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 8cadbc5..e916887 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -533,7 +533,7 @@ class Analyzer(
                 case min: Min if isDistinct => min
                 // For other aggregate functions, DISTINCT keyword is not supported for now.
                 // Once we converted to the new code path, we will allow using DISTINCT keyword.
-                case other if isDistinct =>
+                case other: AggregateExpression1 if isDistinct =>
                   failAnalysis(s"$name does not support DISTINCT keyword.")
                 // If it does not have DISTINCT keyword, we will return it as is.
                 case other => other

http://git-wip-us.apache.org/repos/asf/spark/blob/408e64b2/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
index d3fee1a..10bd19c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
@@ -23,18 +23,18 @@ import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCod
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.types._
 
-/** The mode of an [[AggregateFunction1]]. */
+/** The mode of an [[AggregateFunction2]]. */
 private[sql] sealed trait AggregateMode
 
 /**
- * An [[AggregateFunction1]] with [[Partial]] mode is used for partial aggregation.
+ * An [[AggregateFunction2]] with [[Partial]] mode is used for partial aggregation.
  * This function updates the given aggregation buffer with the original input of this
  * function. When it has processed all input rows, the aggregation buffer is returned.
  */
 private[sql] case object Partial extends AggregateMode
 
 /**
- * An [[AggregateFunction1]] with [[PartialMerge]] mode is used to merge aggregation buffers
+ * An [[AggregateFunction2]] with [[PartialMerge]] mode is used to merge aggregation buffers
  * containing intermediate results for this function.
  * This function updates the given aggregation buffer by merging multiple aggregation buffers.
  * When it has processed all input rows, the aggregation buffer is returned.
@@ -42,8 +42,8 @@ private[sql] case object Partial extends AggregateMode
 private[sql] case object PartialMerge extends AggregateMode
 
 /**
- * An [[AggregateFunction1]] with [[PartialMerge]] mode is used to merge aggregation buffers
- * containing intermediate results for this function and the generate final result.
+ * An [[AggregateFunction2]] with [[PartialMerge]] mode is used to merge aggregation buffers
+ * containing intermediate results for this function and then generate final result.
  * This function updates the given aggregation buffer by merging multiple aggregation buffers.
  * When it has processed all input rows, the final result of this function is returned.
  */
@@ -85,12 +85,12 @@ private[sql] case class AggregateExpression2(
   override def nullable: Boolean = aggregateFunction.nullable
 
   override def references: AttributeSet = {
-    val childReferemces = mode match {
+    val childReferences = mode match {
       case Partial | Complete => aggregateFunction.references.toSeq
       case PartialMerge | Final => aggregateFunction.bufferAttributes
     }
 
-    AttributeSet(childReferemces)
+    AttributeSet(childReferences)
   }
 
   override def toString: String = s"(${aggregateFunction}2,mode=$mode,isDistinct=$isDistinct)"
@@ -99,10 +99,8 @@ private[sql] case class AggregateExpression2(
 abstract class AggregateFunction2
   extends Expression with ImplicitCastInputTypes {
 
-  self: Product =>
-
   /** An aggregate function is not foldable. */
-  override def foldable: Boolean = false
+  final override def foldable: Boolean = false
 
   /**
    * The offset of this function's buffer in the underlying buffer shared with other functions.

http://git-wip-us.apache.org/repos/asf/spark/blob/408e64b2/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
index d31e265..41a0c51 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
@@ -224,13 +224,13 @@ private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[
       // compatible.
       // TODO: ASSUMES TRANSITIVITY?
       def compatible: Boolean =
-        !operator.children
+        operator.children
           .map(_.outputPartitioning)
           .sliding(2)
-          .map {
+          .forall {
             case Seq(a) => true
             case Seq(a, b) => a.compatibleWith(b)
-          }.exists(!_)
+          }
 
       // Adds Exchange or Sort operators as required
       def addOperatorsIfNecessary(

http://git-wip-us.apache.org/repos/asf/spark/blob/408e64b2/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index f54aa20..eb4be19 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -190,12 +190,12 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
         sqlContext.conf.codegenEnabled).isDefined
     }
 
-    def canBeCodeGened(aggs: Seq[AggregateExpression1]): Boolean = !aggs.exists {
-      case _: CombineSum | _: Sum | _: Count | _: Max | _: Min |  _: CombineSetsAndCount => false
+    def canBeCodeGened(aggs: Seq[AggregateExpression1]): Boolean = aggs.forall {
+      case _: CombineSum | _: Sum | _: Count | _: Max | _: Min |  _: CombineSetsAndCount => true
       // The generated set implementation is pretty limited ATM.
       case CollectHashSet(exprs) if exprs.size == 1  &&
-           Seq(IntegerType, LongType).contains(exprs.head.dataType) => false
-      case _ => true
+           Seq(IntegerType, LongType).contains(exprs.head.dataType) => true
+      case _ => false
     }
 
     def allAggregates(exprs: Seq[Expression]): Seq[AggregateExpression1] =

http://git-wip-us.apache.org/repos/asf/spark/blob/408e64b2/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/sortBasedIterators.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/sortBasedIterators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/sortBasedIterators.scala
index ce1cbdc..b8e95a5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/sortBasedIterators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/sortBasedIterators.scala
@@ -67,13 +67,6 @@ private[sql] abstract class SortAggregationIterator(
     functions
   }
 
-  // All non-algebraic aggregate functions.
-  protected val nonAlgebraicAggregateFunctions: Array[AggregateFunction2] = {
-    aggregateFunctions.collect {
-      case func: AggregateFunction2 if !func.isInstanceOf[AlgebraicAggregate] => func
-    }.toArray
-  }
-
   // Positions of those non-algebraic aggregate functions in aggregateFunctions.
   // For example, we have func1, func2, func3, func4 in aggregateFunctions, and
   // func2 and func3 are non-algebraic aggregate functions.
@@ -91,6 +84,10 @@ private[sql] abstract class SortAggregationIterator(
     positions.toArray
   }
 
+  // All non-algebraic aggregate functions.
+  protected val nonAlgebraicAggregateFunctions: Array[AggregateFunction2] =
+    nonAlgebraicAggregateFunctionPositions.map(aggregateFunctions)
+
   // This is used to project expressions for the grouping expressions.
   protected val groupGenerator =
     newMutableProjection(groupingExpressions, inputAttributes)()
@@ -179,8 +176,6 @@ private[sql] abstract class SortAggregationIterator(
       // For the below compare method, we do not need to make a copy of groupingKey.
       val groupingKey = groupGenerator(currentRow)
       // Check if the current row belongs the current input row.
-      currentGroupingKey.equals(groupingKey)
-
       if (currentGroupingKey == groupingKey) {
         processRow(currentRow)
       } else {
@@ -288,10 +283,7 @@ class PartialSortAggregationIterator(
 
   // This projection is used to update buffer values for all AlgebraicAggregates.
   private val algebraicUpdateProjection = {
-    val bufferSchema = aggregateFunctions.flatMap {
-      case ae: AlgebraicAggregate => ae.bufferAttributes
-      case agg: AggregateFunction2 => agg.bufferAttributes
-    }
+    val bufferSchema = aggregateFunctions.flatMap(_.bufferAttributes)
     val updateExpressions = aggregateFunctions.flatMap {
       case ae: AlgebraicAggregate => ae.updateExpressions
       case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp)
@@ -348,19 +340,14 @@ class PartialMergeSortAggregationIterator(
     inputAttributes,
     inputIter) {
 
-  private val placeholderAttribtues =
+  private val placeholderAttributes =
     Seq.fill(initialBufferOffset)(AttributeReference("placeholder", NullType)())
 
   // This projection is used to merge buffer values for all AlgebraicAggregates.
   private val algebraicMergeProjection = {
     val bufferSchemata =
-      placeholderAttribtues ++ aggregateFunctions.flatMap {
-        case ae: AlgebraicAggregate => ae.bufferAttributes
-        case agg: AggregateFunction2 => agg.bufferAttributes
-      } ++ placeholderAttribtues ++ aggregateFunctions.flatMap {
-        case ae: AlgebraicAggregate => ae.cloneBufferAttributes
-        case agg: AggregateFunction2 => agg.cloneBufferAttributes
-      }
+      placeholderAttributes ++ aggregateFunctions.flatMap(_.bufferAttributes) ++
+        placeholderAttributes ++ aggregateFunctions.flatMap(_.cloneBufferAttributes)
     val mergeExpressions = placeholderExpressions ++ aggregateFunctions.flatMap {
       case ae: AlgebraicAggregate => ae.mergeExpressions
       case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp)
@@ -444,13 +431,8 @@ class FinalSortAggregationIterator(
   // This projection is used to merge buffer values for all AlgebraicAggregates.
   private val algebraicMergeProjection = {
     val bufferSchemata =
-      offsetAttributes ++ aggregateFunctions.flatMap {
-        case ae: AlgebraicAggregate => ae.bufferAttributes
-        case agg: AggregateFunction2 => agg.bufferAttributes
-      } ++ offsetAttributes ++ aggregateFunctions.flatMap {
-        case ae: AlgebraicAggregate => ae.cloneBufferAttributes
-        case agg: AggregateFunction2 => agg.cloneBufferAttributes
-      }
+      offsetAttributes ++ aggregateFunctions.flatMap(_.bufferAttributes) ++
+        offsetAttributes ++ aggregateFunctions.flatMap(_.cloneBufferAttributes)
     val mergeExpressions = placeholderExpressions ++ aggregateFunctions.flatMap {
       case ae: AlgebraicAggregate => ae.mergeExpressions
       case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp)
@@ -462,13 +444,8 @@ class FinalSortAggregationIterator(
   // This projection is used to evaluate all AlgebraicAggregates.
   private val algebraicEvalProjection = {
     val bufferSchemata =
-      offsetAttributes ++ aggregateFunctions.flatMap {
-        case ae: AlgebraicAggregate => ae.bufferAttributes
-        case agg: AggregateFunction2 => agg.bufferAttributes
-      } ++ offsetAttributes ++ aggregateFunctions.flatMap {
-        case ae: AlgebraicAggregate => ae.cloneBufferAttributes
-        case agg: AggregateFunction2 => agg.cloneBufferAttributes
-      }
+      offsetAttributes ++ aggregateFunctions.flatMap(_.bufferAttributes) ++
+        offsetAttributes ++ aggregateFunctions.flatMap(_.cloneBufferAttributes)
     val evalExpressions = aggregateFunctions.map {
       case ae: AlgebraicAggregate => ae.evaluateExpression
       case agg: AggregateFunction2 => NoOp
@@ -599,11 +576,10 @@ class FinalAndCompleteSortAggregationIterator(
   }
 
   // All non-algebraic aggregate functions with mode Final.
-  private val finalNonAlgebraicAggregateFunctions: Array[AggregateFunction2] = {
+  private val finalNonAlgebraicAggregateFunctions: Array[AggregateFunction2] =
     finalAggregateFunctions.collect {
       case func: AggregateFunction2 if !func.isInstanceOf[AlgebraicAggregate] => func
-    }.toArray
-  }
+    }
 
   // All aggregate functions with mode Complete.
   private val completeAggregateFunctions: Array[AggregateFunction2] = {
@@ -617,11 +593,10 @@ class FinalAndCompleteSortAggregationIterator(
   }
 
   // All non-algebraic aggregate functions with mode Complete.
-  private val completeNonAlgebraicAggregateFunctions: Array[AggregateFunction2] = {
+  private val completeNonAlgebraicAggregateFunctions: Array[AggregateFunction2] =
     completeAggregateFunctions.collect {
       case func: AggregateFunction2 if !func.isInstanceOf[AlgebraicAggregate] => func
-    }.toArray
-  }
+    }
 
   // This projection is used to merge buffer values for all AlgebraicAggregates with mode
   // Final.
@@ -633,13 +608,9 @@ class FinalAndCompleteSortAggregationIterator(
     val completeOffsetExpressions = Seq.fill(numCompleteOffsetAttributes)(NoOp)
 
     val bufferSchemata =
-      offsetAttributes ++ finalAggregateFunctions.flatMap {
-        case ae: AlgebraicAggregate => ae.bufferAttributes
-        case agg: AggregateFunction2 => agg.bufferAttributes
-      } ++ completeOffsetAttributes ++ offsetAttributes ++ finalAggregateFunctions.flatMap {
-        case ae: AlgebraicAggregate => ae.cloneBufferAttributes
-        case agg: AggregateFunction2 => agg.cloneBufferAttributes
-      } ++ completeOffsetAttributes
+      offsetAttributes ++ finalAggregateFunctions.flatMap(_.bufferAttributes) ++
+        completeOffsetAttributes ++ offsetAttributes ++
+        finalAggregateFunctions.flatMap(_.cloneBufferAttributes) ++ completeOffsetAttributes
     val mergeExpressions =
       placeholderExpressions ++ finalAggregateFunctions.flatMap {
         case ae: AlgebraicAggregate => ae.mergeExpressions
@@ -658,10 +629,8 @@ class FinalAndCompleteSortAggregationIterator(
     val finalOffsetExpressions = Seq.fill(numFinalOffsetAttributes)(NoOp)
 
     val bufferSchema =
-      offsetAttributes ++ finalOffsetAttributes ++ completeAggregateFunctions.flatMap {
-        case ae: AlgebraicAggregate => ae.bufferAttributes
-        case agg: AggregateFunction2 => agg.bufferAttributes
-      }
+      offsetAttributes ++ finalOffsetAttributes ++
+        completeAggregateFunctions.flatMap(_.bufferAttributes)
     val updateExpressions =
       placeholderExpressions ++ finalOffsetExpressions ++ completeAggregateFunctions.flatMap {
         case ae: AlgebraicAggregate => ae.updateExpressions
@@ -673,13 +642,8 @@ class FinalAndCompleteSortAggregationIterator(
   // This projection is used to evaluate all AlgebraicAggregates.
   private val algebraicEvalProjection = {
     val bufferSchemata =
-      offsetAttributes ++ aggregateFunctions.flatMap {
-        case ae: AlgebraicAggregate => ae.bufferAttributes
-        case agg: AggregateFunction2 => agg.bufferAttributes
-      } ++ offsetAttributes ++ aggregateFunctions.flatMap {
-        case ae: AlgebraicAggregate => ae.cloneBufferAttributes
-        case agg: AggregateFunction2 => agg.cloneBufferAttributes
-      }
+      offsetAttributes ++ aggregateFunctions.flatMap(_.bufferAttributes) ++
+        offsetAttributes ++ aggregateFunctions.flatMap(_.cloneBufferAttributes)
     val evalExpressions = aggregateFunctions.map {
       case ae: AlgebraicAggregate => ae.evaluateExpression
       case agg: AggregateFunction2 => NoOp

http://git-wip-us.apache.org/repos/asf/spark/blob/408e64b2/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala
index 1cb2771..5bbe6c1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala
@@ -191,10 +191,7 @@ object Utils {
     }
     val groupExpressionMap = namedGroupingExpressions.toMap
     val namedGroupingAttributes = namedGroupingExpressions.map(_._2.toAttribute)
-    val partialAggregateExpressions = aggregateExpressions.map {
-      case AggregateExpression2(aggregateFunction, mode, isDistinct) =>
-        AggregateExpression2(aggregateFunction, Partial, isDistinct)
-    }
+    val partialAggregateExpressions = aggregateExpressions.map(_.copy(mode = Partial))
     val partialAggregateAttributes = partialAggregateExpressions.flatMap { agg =>
       agg.aggregateFunction.bufferAttributes
     }
@@ -208,10 +205,7 @@ object Utils {
         child)
 
     // 2. Create an Aggregate Operator for final aggregations.
-    val finalAggregateExpressions = aggregateExpressions.map {
-      case AggregateExpression2(aggregateFunction, mode, isDistinct) =>
-        AggregateExpression2(aggregateFunction, Final, isDistinct)
-    }
+    val finalAggregateExpressions = aggregateExpressions.map(_.copy(mode = Final))
     val finalAggregateAttributes =
       finalAggregateExpressions.map {
         expr => aggregateFunctionMap(expr.aggregateFunction, expr.isDistinct)

http://git-wip-us.apache.org/repos/asf/spark/blob/408e64b2/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index ab8dce6..95a1106 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -1518,18 +1518,19 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils {
 
   test("SPARK-8945: add and subtract expressions for interval type") {
     import org.apache.spark.unsafe.types.Interval
+    import org.apache.spark.unsafe.types.Interval.MICROS_PER_WEEK
 
     val df = sql("select interval 3 years -3 month 7 week 123 microseconds as i")
-    checkAnswer(df, Row(new Interval(12 * 3 - 3, 7L * 1000 * 1000 * 3600 * 24 * 7 + 123)))
+    checkAnswer(df, Row(new Interval(12 * 3 - 3, 7L * MICROS_PER_WEEK + 123)))
 
     checkAnswer(df.select(df("i") + new Interval(2, 123)),
-      Row(new Interval(12 * 3 - 3 + 2, 7L * 1000 * 1000 * 3600 * 24 * 7 + 123 + 123)))
+      Row(new Interval(12 * 3 - 3 + 2, 7L * MICROS_PER_WEEK + 123 + 123)))
 
     checkAnswer(df.select(df("i") - new Interval(2, 123)),
-      Row(new Interval(12 * 3 - 3 - 2, 7L * 1000 * 1000 * 3600 * 24 * 7 + 123 - 123)))
+      Row(new Interval(12 * 3 - 3 - 2, 7L * MICROS_PER_WEEK + 123 - 123)))
 
     // unary minus
     checkAnswer(df.select(-df("i")),
-      Row(new Interval(-(12 * 3 - 3), -(7L * 1000 * 1000 * 3600 * 24 * 7 + 123))))
+      Row(new Interval(-(12 * 3 - 3), -(7L * MICROS_PER_WEEK + 123))))
   }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org