Posted to commits@spark.apache.org by rx...@apache.org on 2015/07/24 08:40:07 UTC
spark git commit: [SPARK-9294][SQL] cleanup comments, code style, naming typo for the new aggregation
Repository: spark
Updated Branches:
refs/heads/master d4d762f27 -> 408e64b28
[SPARK-9294][SQL] cleanup comments, code style, naming typo for the new aggregation
fix some comments and code style for https://github.com/apache/spark/pull/7458
Author: Wenchen Fan <cl...@outlook.com>
Closes #7619 from cloud-fan/agg-clean and squashes the following commits:
3925457 [Wenchen Fan] one more...
cc78357 [Wenchen Fan] one more cleanup
26f6a93 [Wenchen Fan] some minor cleanup for the new aggregation
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/408e64b2
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/408e64b2
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/408e64b2
Branch: refs/heads/master
Commit: 408e64b284ef8bd6796d815b5eb603312d090b74
Parents: d4d762f
Author: Wenchen Fan <cl...@outlook.com>
Authored: Thu Jul 23 23:40:01 2015 -0700
Committer: Reynold Xin <rx...@databricks.com>
Committed: Thu Jul 23 23:40:01 2015 -0700
----------------------------------------------------------------------
.../spark/sql/catalyst/analysis/Analyzer.scala | 2 +-
.../expressions/aggregate/interfaces.scala | 18 ++---
.../apache/spark/sql/execution/Exchange.scala | 6 +-
.../spark/sql/execution/SparkStrategies.scala | 8 +-
.../aggregate/sortBasedIterators.scala | 82 ++++++--------------
.../spark/sql/execution/aggregate/utils.scala | 10 +--
.../org/apache/spark/sql/SQLQuerySuite.scala | 9 ++-
7 files changed, 46 insertions(+), 89 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/408e64b2/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 8cadbc5..e916887 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -533,7 +533,7 @@ class Analyzer(
case min: Min if isDistinct => min
// For other aggregate functions, DISTINCT keyword is not supported for now.
// Once we converted to the new code path, we will allow using DISTINCT keyword.
- case other if isDistinct =>
+ case other: AggregateExpression1 if isDistinct =>
failAnalysis(s"$name does not support DISTINCT keyword.")
// If it does not have DISTINCT keyword, we will return it as is.
case other => other
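The fix above narrows the catch-all: an untyped `case other` would also match new-path aggregate expressions and reject their DISTINCT usage. A minimal standalone sketch of the idea, with toy types standing in for AggregateExpression1/2 (not the Spark classes):

    sealed trait AggExpr
    case class OldAgg(name: String) extends AggExpr  // stands in for AggregateExpression1
    case class NewAgg(name: String) extends AggExpr  // stands in for AggregateExpression2

    def resolveDistinct(e: AggExpr, isDistinct: Boolean): AggExpr = e match {
      // Only the old code path rejects DISTINCT; new-path expressions fall
      // through to the catch-all and are returned unchanged.
      case old: OldAgg if isDistinct =>
        sys.error(s"${old.name} does not support DISTINCT keyword.")
      case other => other
    }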
http://git-wip-us.apache.org/repos/asf/spark/blob/408e64b2/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
index d3fee1a..10bd19c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
@@ -23,18 +23,18 @@ import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCod
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.types._
-/** The mode of an [[AggregateFunction1]]. */
+/** The mode of an [[AggregateFunction2]]. */
private[sql] sealed trait AggregateMode
/**
- * An [[AggregateFunction1]] with [[Partial]] mode is used for partial aggregation.
+ * An [[AggregateFunction2]] with [[Partial]] mode is used for partial aggregation.
* This function updates the given aggregation buffer with the original input of this
* function. When it has processed all input rows, the aggregation buffer is returned.
*/
private[sql] case object Partial extends AggregateMode
/**
- * An [[AggregateFunction1]] with [[PartialMerge]] mode is used to merge aggregation buffers
+ * An [[AggregateFunction2]] with [[PartialMerge]] mode is used to merge aggregation buffers
* containing intermediate results for this function.
* This function updates the given aggregation buffer by merging multiple aggregation buffers.
* When it has processed all input rows, the aggregation buffer is returned.
@@ -42,8 +42,8 @@ private[sql] case object Partial extends AggregateMode
private[sql] case object PartialMerge extends AggregateMode
/**
- * An [[AggregateFunction1]] with [[PartialMerge]] mode is used to merge aggregation buffers
- * containing intermediate results for this function and the generate final result.
+ * An [[AggregateFunction2]] with [[PartialMerge]] mode is used to merge aggregation buffers
+ * containing intermediate results for this function and then generate final result.
* This function updates the given aggregation buffer by merging multiple aggregation buffers.
* When it has processed all input rows, the final result of this function is returned.
*/
@@ -85,12 +85,12 @@ private[sql] case class AggregateExpression2(
override def nullable: Boolean = aggregateFunction.nullable
override def references: AttributeSet = {
- val childReferemces = mode match {
+ val childReferences = mode match {
case Partial | Complete => aggregateFunction.references.toSeq
case PartialMerge | Final => aggregateFunction.bufferAttributes
}
- AttributeSet(childReferemces)
+ AttributeSet(childReferences)
}
override def toString: String = s"(${aggregateFunction}2,mode=$mode,isDistinct=$isDistinct)"
@@ -99,10 +99,8 @@ private[sql] case class AggregateExpression2(
abstract class AggregateFunction2
extends Expression with ImplicitCastInputTypes {
- self: Product =>
-
/** An aggregate function is not foldable. */
- override def foldable: Boolean = false
+ final override def foldable: Boolean = false
/**
* The offset of this function's buffer in the underlying buffer shared with other functions.
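The corrected comments describe the two-phase aggregation lifecycle, and the new `final` modifier pins down that aggregates are never constant-foldable. A simplified sketch of both ideas (toy declarations, not the Spark API):

    sealed trait AggregateMode
    case object Partial extends AggregateMode       // update a buffer from raw input rows
    case object PartialMerge extends AggregateMode  // merge intermediate buffers
    case object Final extends AggregateMode         // merge buffers, then emit the result
    case object Complete extends AggregateMode      // update from raw input, then emit

    abstract class AggFunction {
      // An aggregate's value depends on many rows, so it can never be
      // constant-folded; `final` keeps subclasses from overriding this back.
      final def foldable: Boolean = false
    }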
http://git-wip-us.apache.org/repos/asf/spark/blob/408e64b2/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
index d31e265..41a0c51 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
@@ -224,13 +224,13 @@ private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[
// compatible.
// TODO: ASSUMES TRANSITIVITY?
def compatible: Boolean =
- !operator.children
+ operator.children
.map(_.outputPartitioning)
.sliding(2)
- .map {
+ .forall {
case Seq(a) => true
case Seq(a, b) => a.compatibleWith(b)
- }.exists(!_)
+ }
// Adds Exchange or Sort operators as required
def addOperatorsIfNecessary(
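The rewrite above is a De Morgan simplification: for any predicate p, `!xs.map(p).exists(x => !x)` is equivalent to `xs.forall(p)`. A standalone sketch of the same shape over sliding windows, with toy data in place of real output partitionings:

    val partitionings = Seq(4, 4, 4)  // stand-ins for the children's output partitionings
    val compatible: Boolean =
      partitionings
        .sliding(2)
        .forall {
          case Seq(a)    => true    // a single child is trivially compatible
          case Seq(a, b) => a == b  // toy stand-in for a.compatibleWith(b)
        }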
http://git-wip-us.apache.org/repos/asf/spark/blob/408e64b2/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index f54aa20..eb4be19 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -190,12 +190,12 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
sqlContext.conf.codegenEnabled).isDefined
}
- def canBeCodeGened(aggs: Seq[AggregateExpression1]): Boolean = !aggs.exists {
- case _: CombineSum | _: Sum | _: Count | _: Max | _: Min | _: CombineSetsAndCount => false
+ def canBeCodeGened(aggs: Seq[AggregateExpression1]): Boolean = aggs.forall {
+ case _: CombineSum | _: Sum | _: Count | _: Max | _: Min | _: CombineSetsAndCount => true
// The generated set implementation is pretty limited ATM.
case CollectHashSet(exprs) if exprs.size == 1 &&
- Seq(IntegerType, LongType).contains(exprs.head.dataType) => false
- case _ => true
+ Seq(IntegerType, LongType).contains(exprs.head.dataType) => true
+ case _ => false
}
def allAggregates(exprs: Seq[Expression]): Seq[AggregateExpression1] =
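Same transformation here, with one extra subtlety: turning `!aggs.exists { ... }` into `aggs.forall { ... }` requires flipping the result of every case, which the diff does. A toy model (simplified types, not the Spark expression classes):

    sealed trait Agg
    case object Sum extends Agg
    case object Custom extends Agg

    // before: !aggs.exists { case Sum => false; case _ => true }
    // after, equivalent but stating the positive condition directly:
    def canBeCodeGened(aggs: Seq[Agg]): Boolean = aggs.forall {
      case Sum => true   // supported in this toy model
      case _   => false  // anything else disables codegen
    }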
http://git-wip-us.apache.org/repos/asf/spark/blob/408e64b2/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/sortBasedIterators.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/sortBasedIterators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/sortBasedIterators.scala
index ce1cbdc..b8e95a5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/sortBasedIterators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/sortBasedIterators.scala
@@ -67,13 +67,6 @@ private[sql] abstract class SortAggregationIterator(
functions
}
- // All non-algebraic aggregate functions.
- protected val nonAlgebraicAggregateFunctions: Array[AggregateFunction2] = {
- aggregateFunctions.collect {
- case func: AggregateFunction2 if !func.isInstanceOf[AlgebraicAggregate] => func
- }.toArray
- }
-
// Positions of those non-algebraic aggregate functions in aggregateFunctions.
// For example, we have func1, func2, func3, func4 in aggregateFunctions, and
// func2 and func3 are non-algebraic aggregate functions.
@@ -91,6 +84,10 @@ private[sql] abstract class SortAggregationIterator(
positions.toArray
}
+ // All non-algebraic aggregate functions.
+ protected val nonAlgebraicAggregateFunctions: Array[AggregateFunction2] =
+ nonAlgebraicAggregateFunctionPositions.map(aggregateFunctions)
+
// This is used to project expressions for the grouping expressions.
protected val groupGenerator =
newMutableProjection(groupingExpressions, inputAttributes)()
@@ -179,8 +176,6 @@ private[sql] abstract class SortAggregationIterator(
// For the below compare method, we do not need to make a copy of groupingKey.
val groupingKey = groupGenerator(currentRow)
// Check if the current row belongs the current input row.
- currentGroupingKey.equals(groupingKey)
-
if (currentGroupingKey == groupingKey) {
processRow(currentRow)
} else {
@@ -288,10 +283,7 @@ class PartialSortAggregationIterator(
// This projection is used to update buffer values for all AlgebraicAggregates.
private val algebraicUpdateProjection = {
- val bufferSchema = aggregateFunctions.flatMap {
- case ae: AlgebraicAggregate => ae.bufferAttributes
- case agg: AggregateFunction2 => agg.bufferAttributes
- }
+ val bufferSchema = aggregateFunctions.flatMap(_.bufferAttributes)
val updateExpressions = aggregateFunctions.flatMap {
case ae: AlgebraicAggregate => ae.updateExpressions
case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp)
@@ -348,19 +340,14 @@ class PartialMergeSortAggregationIterator(
inputAttributes,
inputIter) {
- private val placeholderAttribtues =
+ private val placeholderAttributes =
Seq.fill(initialBufferOffset)(AttributeReference("placeholder", NullType)())
// This projection is used to merge buffer values for all AlgebraicAggregates.
private val algebraicMergeProjection = {
val bufferSchemata =
- placeholderAttribtues ++ aggregateFunctions.flatMap {
- case ae: AlgebraicAggregate => ae.bufferAttributes
- case agg: AggregateFunction2 => agg.bufferAttributes
- } ++ placeholderAttribtues ++ aggregateFunctions.flatMap {
- case ae: AlgebraicAggregate => ae.cloneBufferAttributes
- case agg: AggregateFunction2 => agg.cloneBufferAttributes
- }
+ placeholderAttributes ++ aggregateFunctions.flatMap(_.bufferAttributes) ++
+ placeholderAttributes ++ aggregateFunctions.flatMap(_.cloneBufferAttributes)
val mergeExpressions = placeholderExpressions ++ aggregateFunctions.flatMap {
case ae: AlgebraicAggregate => ae.mergeExpressions
case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp)
@@ -444,13 +431,8 @@ class FinalSortAggregationIterator(
// This projection is used to merge buffer values for all AlgebraicAggregates.
private val algebraicMergeProjection = {
val bufferSchemata =
- offsetAttributes ++ aggregateFunctions.flatMap {
- case ae: AlgebraicAggregate => ae.bufferAttributes
- case agg: AggregateFunction2 => agg.bufferAttributes
- } ++ offsetAttributes ++ aggregateFunctions.flatMap {
- case ae: AlgebraicAggregate => ae.cloneBufferAttributes
- case agg: AggregateFunction2 => agg.cloneBufferAttributes
- }
+ offsetAttributes ++ aggregateFunctions.flatMap(_.bufferAttributes) ++
+ offsetAttributes ++ aggregateFunctions.flatMap(_.cloneBufferAttributes)
val mergeExpressions = placeholderExpressions ++ aggregateFunctions.flatMap {
case ae: AlgebraicAggregate => ae.mergeExpressions
case agg: AggregateFunction2 => Seq.fill(agg.bufferAttributes.length)(NoOp)
@@ -462,13 +444,8 @@ class FinalSortAggregationIterator(
// This projection is used to evaluate all AlgebraicAggregates.
private val algebraicEvalProjection = {
val bufferSchemata =
- offsetAttributes ++ aggregateFunctions.flatMap {
- case ae: AlgebraicAggregate => ae.bufferAttributes
- case agg: AggregateFunction2 => agg.bufferAttributes
- } ++ offsetAttributes ++ aggregateFunctions.flatMap {
- case ae: AlgebraicAggregate => ae.cloneBufferAttributes
- case agg: AggregateFunction2 => agg.cloneBufferAttributes
- }
+ offsetAttributes ++ aggregateFunctions.flatMap(_.bufferAttributes) ++
+ offsetAttributes ++ aggregateFunctions.flatMap(_.cloneBufferAttributes)
val evalExpressions = aggregateFunctions.map {
case ae: AlgebraicAggregate => ae.evaluateExpression
case agg: AggregateFunction2 => NoOp
@@ -599,11 +576,10 @@ class FinalAndCompleteSortAggregationIterator(
}
// All non-algebraic aggregate functions with mode Final.
- private val finalNonAlgebraicAggregateFunctions: Array[AggregateFunction2] = {
+ private val finalNonAlgebraicAggregateFunctions: Array[AggregateFunction2] =
finalAggregateFunctions.collect {
case func: AggregateFunction2 if !func.isInstanceOf[AlgebraicAggregate] => func
- }.toArray
- }
+ }
// All aggregate functions with mode Complete.
private val completeAggregateFunctions: Array[AggregateFunction2] = {
@@ -617,11 +593,10 @@ class FinalAndCompleteSortAggregationIterator(
}
// All non-algebraic aggregate functions with mode Complete.
- private val completeNonAlgebraicAggregateFunctions: Array[AggregateFunction2] = {
+ private val completeNonAlgebraicAggregateFunctions: Array[AggregateFunction2] =
completeAggregateFunctions.collect {
case func: AggregateFunction2 if !func.isInstanceOf[AlgebraicAggregate] => func
- }.toArray
- }
+ }
// This projection is used to merge buffer values for all AlgebraicAggregates with mode
// Final.
@@ -633,13 +608,9 @@ class FinalAndCompleteSortAggregationIterator(
val completeOffsetExpressions = Seq.fill(numCompleteOffsetAttributes)(NoOp)
val bufferSchemata =
- offsetAttributes ++ finalAggregateFunctions.flatMap {
- case ae: AlgebraicAggregate => ae.bufferAttributes
- case agg: AggregateFunction2 => agg.bufferAttributes
- } ++ completeOffsetAttributes ++ offsetAttributes ++ finalAggregateFunctions.flatMap {
- case ae: AlgebraicAggregate => ae.cloneBufferAttributes
- case agg: AggregateFunction2 => agg.cloneBufferAttributes
- } ++ completeOffsetAttributes
+ offsetAttributes ++ finalAggregateFunctions.flatMap(_.bufferAttributes) ++
+ completeOffsetAttributes ++ offsetAttributes ++
+ finalAggregateFunctions.flatMap(_.cloneBufferAttributes) ++ completeOffsetAttributes
val mergeExpressions =
placeholderExpressions ++ finalAggregateFunctions.flatMap {
case ae: AlgebraicAggregate => ae.mergeExpressions
@@ -658,10 +629,8 @@ class FinalAndCompleteSortAggregationIterator(
val finalOffsetExpressions = Seq.fill(numFinalOffsetAttributes)(NoOp)
val bufferSchema =
- offsetAttributes ++ finalOffsetAttributes ++ completeAggregateFunctions.flatMap {
- case ae: AlgebraicAggregate => ae.bufferAttributes
- case agg: AggregateFunction2 => agg.bufferAttributes
- }
+ offsetAttributes ++ finalOffsetAttributes ++
+ completeAggregateFunctions.flatMap(_.bufferAttributes)
val updateExpressions =
placeholderExpressions ++ finalOffsetExpressions ++ completeAggregateFunctions.flatMap {
case ae: AlgebraicAggregate => ae.updateExpressions
@@ -673,13 +642,8 @@ class FinalAndCompleteSortAggregationIterator(
// This projection is used to evaluate all AlgebraicAggregates.
private val algebraicEvalProjection = {
val bufferSchemata =
- offsetAttributes ++ aggregateFunctions.flatMap {
- case ae: AlgebraicAggregate => ae.bufferAttributes
- case agg: AggregateFunction2 => agg.bufferAttributes
- } ++ offsetAttributes ++ aggregateFunctions.flatMap {
- case ae: AlgebraicAggregate => ae.cloneBufferAttributes
- case agg: AggregateFunction2 => agg.cloneBufferAttributes
- }
+ offsetAttributes ++ aggregateFunctions.flatMap(_.bufferAttributes) ++
+ offsetAttributes ++ aggregateFunctions.flatMap(_.cloneBufferAttributes)
val evalExpressions = aggregateFunctions.map {
case ae: AlgebraicAggregate => ae.evaluateExpression
case agg: AggregateFunction2 => NoOp
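Most of the churn in this file is the same collapse repeated: both arms of each `flatMap` match called a member declared on the common supertype (AlgebraicAggregate extends AggregateFunction2), so the match was redundant. The moved `nonAlgebraicAggregateFunctions` also leans on the fact that a Scala Seq (or a wrapped Array) is a function from index to element. A simplified sketch of both idioms, with toy types:

    abstract class AggregateFunction2 { def bufferAttributes: Seq[String] }
    abstract class AlgebraicAggregate extends AggregateFunction2

    // Both old cases returned bufferAttributes, so one flatMap suffices:
    def bufferSchema(funcs: Seq[AggregateFunction2]): Seq[String] =
      funcs.flatMap(_.bufferAttributes)

    // Mapping positions over a sequence selects elements by index,
    // as positions.map(aggregateFunctions) does above:
    val xs = Vector("a", "b", "c", "d")
    val picks = Seq(1, 2).map(xs)  // Seq("b", "c")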
http://git-wip-us.apache.org/repos/asf/spark/blob/408e64b2/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala
index 1cb2771..5bbe6c1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala
@@ -191,10 +191,7 @@ object Utils {
}
val groupExpressionMap = namedGroupingExpressions.toMap
val namedGroupingAttributes = namedGroupingExpressions.map(_._2.toAttribute)
- val partialAggregateExpressions = aggregateExpressions.map {
- case AggregateExpression2(aggregateFunction, mode, isDistinct) =>
- AggregateExpression2(aggregateFunction, Partial, isDistinct)
- }
+ val partialAggregateExpressions = aggregateExpressions.map(_.copy(mode = Partial))
val partialAggregateAttributes = partialAggregateExpressions.flatMap { agg =>
agg.aggregateFunction.bufferAttributes
}
@@ -208,10 +205,7 @@ object Utils {
child)
// 2. Create an Aggregate Operator for final aggregations.
- val finalAggregateExpressions = aggregateExpressions.map {
- case AggregateExpression2(aggregateFunction, mode, isDistinct) =>
- AggregateExpression2(aggregateFunction, Final, isDistinct)
- }
+ val finalAggregateExpressions = aggregateExpressions.map(_.copy(mode = Final))
val finalAggregateAttributes =
finalAggregateExpressions.map {
expr => aggregateFunctionMap(expr.aggregateFunction, expr.isDistinct)
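Both rewrites here use the case-class `copy` method, which changes one field and keeps the rest, replacing the destructure-and-rebuild match. A toy version (not Spark's AggregateExpression2):

    sealed trait Mode
    case object Partial extends Mode
    case object Final extends Mode

    case class AggExpr(name: String, mode: Mode, isDistinct: Boolean)

    val exprs   = Seq(AggExpr("sum", Final, isDistinct = false))
    val partial = exprs.map(_.copy(mode = Partial))  // only `mode` changes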
http://git-wip-us.apache.org/repos/asf/spark/blob/408e64b2/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index ab8dce6..95a1106 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -1518,18 +1518,19 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils {
test("SPARK-8945: add and subtract expressions for interval type") {
import org.apache.spark.unsafe.types.Interval
+ import org.apache.spark.unsafe.types.Interval.MICROS_PER_WEEK
val df = sql("select interval 3 years -3 month 7 week 123 microseconds as i")
- checkAnswer(df, Row(new Interval(12 * 3 - 3, 7L * 1000 * 1000 * 3600 * 24 * 7 + 123)))
+ checkAnswer(df, Row(new Interval(12 * 3 - 3, 7L * MICROS_PER_WEEK + 123)))
checkAnswer(df.select(df("i") + new Interval(2, 123)),
- Row(new Interval(12 * 3 - 3 + 2, 7L * 1000 * 1000 * 3600 * 24 * 7 + 123 + 123)))
+ Row(new Interval(12 * 3 - 3 + 2, 7L * MICROS_PER_WEEK + 123 + 123)))
checkAnswer(df.select(df("i") - new Interval(2, 123)),
- Row(new Interval(12 * 3 - 3 - 2, 7L * 1000 * 1000 * 3600 * 24 * 7 + 123 - 123)))
+ Row(new Interval(12 * 3 - 3 - 2, 7L * MICROS_PER_WEEK + 123 - 123)))
// unary minus
checkAnswer(df.select(-df("i")),
- Row(new Interval(-(12 * 3 - 3), -(7L * 1000 * 1000 * 3600 * 24 * 7 + 123))))
+ Row(new Interval(-(12 * 3 - 3), -(7L * MICROS_PER_WEEK + 123))))
}
}
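The named constant replaces a magic product; spelled out (assuming MICROS_PER_WEEK equals this value, as the before/after equivalence implies):

    // 1,000,000 us/s * 3,600 s/h * 24 h/day * 7 days = 604,800,000,000 us/week
    val microsPerWeek: Long = 1000L * 1000 * 3600 * 24 * 7
    assert(microsPerWeek == 604800000000L)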