You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by rx...@apache.org on 2016/04/23 02:44:05 UTC
[5/7] spark git commit: [SPARK-14855][SQL] Add "Exec" suffix to
physical operators
http://git-wip-us.apache.org/repos/asf/spark/blob/d7d0cad0/sql/core/src/main/scala/org/apache/spark/sql/execution/WindowExec.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WindowExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WindowExec.scala
new file mode 100644
index 0000000..97bbab6
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WindowExec.scala
@@ -0,0 +1,1008 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution
+
+import java.util
+
+import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.spark.{SparkEnv, TaskContext}
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate._
+import org.apache.spark.sql.catalyst.plans.physical._
+import org.apache.spark.sql.types.IntegerType
+import org.apache.spark.util.collection.unsafe.sort.{UnsafeExternalSorter, UnsafeSorterIterator}
+
+/**
+ * This class calculates and outputs (windowed) aggregates over the rows in a single (sorted)
+ * partition. The aggregates are calculated for each row in the group. Special processing
+ * instructions, frames, are used to calculate these aggregates. Frames are processed in the order
+ * specified in the window specification (the ORDER BY ... clause). There are five different frame
+ * types:
+ * - Entire partition: The frame is the entire partition, i.e.
+ * UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING. For this case, window function will take all
+ * rows as inputs and be evaluated once.
+ * - Growing frame: We only add new rows into the frame, i.e. UNBOUNDED PRECEDING AND ....
+ * Every time we move to a new row to process, we add some rows to the frame. We do not remove
+ * rows from this frame.
+ * - Shrinking frame: We only remove rows from the frame, i.e. ... AND UNBOUNDED FOLLOWING.
+ * Every time we move to a new row to process, we remove some rows from the frame. We do not add
+ * rows to this frame.
+ * - Moving frame: Every time we move to a new row to process, we remove some rows from the frame
+ * and we add some rows to the frame. Examples are:
+ * 1 PRECEDING AND CURRENT ROW and 1 FOLLOWING AND 2 FOLLOWING.
+ * - Offset frame: The frame consists of one row, which is an offset number of rows away from the
+ * current row. Only [[OffsetWindowFunction]]s can be processed in an offset frame.
+ *
+ * Different frame boundaries can be used in Growing, Shrinking and Moving frames. A frame
+ * boundary can be either Row or Range based:
+ * - Row Based: A row based boundary is based on the position of the row within the partition.
+ * An offset indicates the number of rows above or below the current row, the frame for the
+ * current row starts or ends. For instance, given a row based sliding frame with a lower bound
+ * offset of -1 and an upper bound offset of +2. The frame for row with index 5 would range from
+ * index 4 to index 7.
+ * - Range based: A range based boundary is based on the actual value of the ORDER BY
+ * expression(s). An offset is used to alter the value of the ORDER BY expression, for
+ * instance if the current order by expression has a value of 10 and the lower bound offset
+ * is -3, the resulting lower bound for the current row will be 10 - 3 = 7. This however puts a
+ * number of constraints on the ORDER BY expressions: there can be only one expression and this
+ * expression must have a numerical data type. An exception can be made when the offset is 0,
+ * because no value modification is needed, in this case multiple and non-numeric ORDER BY
+ * expressions are allowed.
+ *
+ * This is quite an expensive operator because every row for a single group must be in the same
+ * partition and partitions must be sorted according to the grouping and sort order. The operator
+ * requires the planner to take care of the partitioning and sorting.
+ *
+ * The operator is semi-blocking. The window functions and aggregates are calculated one group at
+ * a time, the result will only be made available after the processing for the entire group has
+ * finished. The operator is able to process different frame configurations at the same time. This
+ * is done by delegating the actual frame processing (i.e. calculation of the window functions) to
+ * specialized classes, see [[WindowFunctionFrame]], which take care of their own frame type:
+ * Entire Partition, Sliding, Growing, Shrinking & Offset. Boundary evaluation is also delegated to a pair
+ * of specialized classes: [[RowBoundOrdering]] & [[RangeBoundOrdering]].
+ */
+case class WindowExec(
+    windowExpression: Seq[NamedExpression],
+    partitionSpec: Seq[Expression],
+    orderSpec: Seq[SortOrder],
+    child: SparkPlan)
+  extends UnaryExecNode {
+
+  /** The child's output columns followed by one column per window expression. */
+  override def output: Seq[Attribute] =
+    child.output ++ windowExpression.map(_.toAttribute)
+
+  /**
+   * All rows of a window partition must be co-located. Without a PARTITION BY clause every row
+   * is moved into a single partition, which is why a warning is logged in that case.
+   */
+  override def requiredChildDistribution: Seq[Distribution] = {
+    if (partitionSpec.isEmpty) {
+      // TODO: consider only warning when the amount of data is large (e.g. more than 100 MB).
+      logWarning("No Partition Defined for Window operation! Moving all data to a single "
+        + "partition, this can cause serious performance degradation.")
+      AllTuples :: Nil
+    } else ClusteredDistribution(partitionSpec) :: Nil
+  }
+
+  /**
+   * Input must be sorted by the partition expressions first (so every window partition forms a
+   * contiguous run of rows) and then by the window ORDER BY expressions.
+   */
+  override def requiredChildOrdering: Seq[Seq[SortOrder]] =
+    Seq(partitionSpec.map(SortOrder(_, Ascending)) ++ orderSpec)
+
+  override def outputOrdering: Seq[SortOrder] = child.outputOrdering
+
+  /**
+   * Create a bound ordering object for a given frame type and offset. A bound ordering object is
+   * used to determine which input row lies within the frame boundaries of an output row.
+   *
+   * This method uses Code Generation. It can only be used on the executor side.
+   *
+   * @param frameType to evaluate. This can either be Row or Range based.
+   * @param offset with respect to the row.
+   * @return a bound ordering object.
+   */
+  private[this] def createBoundOrdering(frameType: FrameType, offset: Int): BoundOrdering = {
+    frameType match {
+      case RangeFrame =>
+        val (exprs, current, bound) = if (offset == 0) {
+          // Use the entire order expression when the offset is 0. Two separate projections are
+          // built (presumably because a mutable projection reuses its output row).
+          val exprs = orderSpec.map(_.child)
+          val buildProjection = () => newMutableProjection(exprs, child.output)
+          (orderSpec, buildProjection(), buildProjection())
+        } else if (orderSpec.size == 1) {
+          // Use only the first order expression when the offset is non-zero.
+          val sortExpr = orderSpec.head
+          val expr = sortExpr.child
+          // Create the projection which returns the current 'value'.
+          val current = newMutableProjection(expr :: Nil, child.output)
+          // Flip the sign of the offset when the order is descending.
+          val boundOffset = sortExpr.direction match {
+            case Descending => -offset
+            case Ascending => offset
+          }
+          // Create the projection which returns the current 'value' modified by adding the
+          // offset.
+          val boundExpr = Add(expr, Cast(Literal.create(boundOffset, IntegerType), expr.dataType))
+          val bound = newMutableProjection(boundExpr :: Nil, child.output)
+          (sortExpr :: Nil, current, bound)
+        } else {
+          sys.error("Non-Zero range offsets are not supported for windows " +
+            "with multiple order expressions.")
+        }
+        // Construct the ordering. This is used to compare the result of current value projection
+        // to the result of bound value projection. This is done manually because we want to use
+        // Code Generation (if it is enabled).
+        val sortExprs = exprs.zipWithIndex.map { case (e, i) =>
+          SortOrder(BoundReference(i, e.dataType, e.nullable), e.direction)
+        }
+        val ordering = newOrdering(sortExprs, Nil)
+        RangeBoundOrdering(ordering, current, bound)
+      case RowFrame => RowBoundOrdering(offset)
+    }
+  }
+
+  /**
+   * Collection containing an entry for each window frame to process. Each entry contains a frame's
+   * WindowExpressions and a factory function for the frame's WindowFunctionFrame.
+   */
+  private[this] lazy val windowFrameExpressionFactoryPairs = {
+    // Frames are keyed by (function kind, frame type, lower offset, upper offset). A `None`
+    // offset marks an unbounded side of the frame (see the factories below).
+    type FrameKey = (String, FrameType, Option[Int], Option[Int])
+    type ExpressionBuffer = mutable.Buffer[Expression]
+    val framedFunctions = mutable.Map.empty[FrameKey, (ExpressionBuffer, ExpressionBuffer)]
+
+    // Add a window expression and its underlying function to the group of its frame.
+    def collect(tpe: String, fr: SpecifiedWindowFrame, e: Expression, fn: Expression): Unit = {
+      val key = (tpe, fr.frameType, FrameBoundary(fr.frameStart), FrameBoundary(fr.frameEnd))
+      val (es, fns) = framedFunctions.getOrElseUpdate(
+        key, (ArrayBuffer.empty[Expression], ArrayBuffer.empty[Expression]))
+      es.append(e)
+      fns.append(fn)
+    }
+
+    // Collect all valid window functions and group them by their frame.
+    windowExpression.foreach { x =>
+      x.foreach {
+        case e @ WindowExpression(function, spec) =>
+          val frame = spec.frameSpecification.asInstanceOf[SpecifiedWindowFrame]
+          function match {
+            case AggregateExpression(f, _, _, _) => collect("AGGREGATE", frame, e, f)
+            case f: AggregateWindowFunction => collect("AGGREGATE", frame, e, f)
+            case f: OffsetWindowFunction => collect("OFFSET", frame, e, f)
+            case f => sys.error(s"Unsupported window function: $f")
+          }
+        case _ =>
+      }
+    }
+
+    // Map the groups to a (unbound) expression and frame factory pair.
+    var numExpressions = 0
+    framedFunctions.toSeq.map {
+      case (key, (expressions, functionSeq)) =>
+        // Position of this frame's first result in the shared window function result row.
+        val ordinal = numExpressions
+        val functions = functionSeq.toArray
+
+        // Construct an aggregate processor if we need one. This is a `def`, so every frame
+        // created by a factory gets its own processor.
+        def processor = AggregateProcessor(
+          functions,
+          ordinal,
+          child.output,
+          (expressions, schema) =>
+            newMutableProjection(expressions, schema, subexpressionEliminationEnabled))
+
+        // Create the factory
+        val factory = key match {
+          // Offset Frame
+          case ("OFFSET", RowFrame, Some(offset), Some(h)) if offset == h =>
+            target: MutableRow =>
+              new OffsetWindowFunctionFrame(
+                target,
+                ordinal,
+                functions,
+                child.output,
+                (expressions, schema) =>
+                  newMutableProjection(expressions, schema, subexpressionEliminationEnabled),
+                offset)
+
+          // Growing Frame.
+          case ("AGGREGATE", frameType, None, Some(high)) =>
+            target: MutableRow => {
+              new UnboundedPrecedingWindowFunctionFrame(
+                target,
+                processor,
+                createBoundOrdering(frameType, high))
+            }
+
+          // Shrinking Frame.
+          case ("AGGREGATE", frameType, Some(low), None) =>
+            target: MutableRow => {
+              new UnboundedFollowingWindowFunctionFrame(
+                target,
+                processor,
+                createBoundOrdering(frameType, low))
+            }
+
+          // Moving Frame.
+          case ("AGGREGATE", frameType, Some(low), Some(high)) =>
+            target: MutableRow => {
+              new SlidingWindowFunctionFrame(
+                target,
+                processor,
+                createBoundOrdering(frameType, low),
+                createBoundOrdering(frameType, high))
+            }
+
+          // Entire Partition Frame.
+          case ("AGGREGATE", frameType, None, None) =>
+            target: MutableRow => {
+              new UnboundedWindowFunctionFrame(target, processor)
+            }
+
+          // Any other combination (e.g. an OFFSET function in a range frame or in a frame that
+          // does not span exactly one row) has no frame implementation. Fail with a descriptive
+          // error instead of a bare MatchError.
+          case _ =>
+            sys.error(s"Unsupported window frame $key for functions: ${functions.mkString(", ")}")
+        }
+
+        // Keep track of the number of expressions. This is a side-effect in a map...
+        numExpressions += expressions.size
+
+        // Create the Frame Expression - Factory pair.
+        (expressions, factory)
+    }
+  }
+
+  /**
+   * Create the resulting projection.
+   *
+   * This method uses Code Generation. It can only be used on the executor side.
+   *
+   * @param expressions unbound ordered function expressions.
+   * @return the final resulting projection.
+   */
+  private[this] def createResultProjection(
+      expressions: Seq[Expression]): UnsafeProjection = {
+    val references = expressions.zipWithIndex.map { case (e, i) =>
+      // Results of window expressions will be on the right side of child's output
+      BoundReference(child.output.size + i, e.dataType, e.nullable)
+    }
+    val unboundToRefMap = expressions.zip(references).toMap
+    val patchedWindowExpression = windowExpression.map(_.transform(unboundToRefMap))
+    UnsafeProjection.create(
+      child.output ++ patchedWindowExpression,
+      child.output)
+  }
+
+  protected override def doExecute(): RDD[InternalRow] = {
+    // Unwrap the expressions and factories from the map.
+    val expressions = windowFrameExpressionFactoryPairs.flatMap(_._1)
+    val factories = windowFrameExpressionFactoryPairs.map(_._2).toArray
+
+    // Start processing.
+    child.execute().mapPartitions { stream =>
+      new Iterator[InternalRow] {
+
+        // Get all relevant projections.
+        val result = createResultProjection(expressions)
+        val grouping = UnsafeProjection.create(partitionSpec, child.output)
+
+        // Manage the stream and the grouping. `nextGroup` holds the partition key of `nextRow`.
+        var nextRow: UnsafeRow = null
+        var nextGroup: UnsafeRow = null
+        var nextRowAvailable: Boolean = false
+        private[this] def fetchNextRow(): Unit = {
+          nextRowAvailable = stream.hasNext
+          if (nextRowAvailable) {
+            nextRow = stream.next().asInstanceOf[UnsafeRow]
+            nextGroup = grouping(nextRow)
+          } else {
+            nextRow = null
+            nextGroup = null
+          }
+        }
+        fetchNextRow()
+
+        // Manage the current partition.
+        val rows = ArrayBuffer.empty[UnsafeRow]
+        val inputFields = child.output.length
+        var sorter: UnsafeExternalSorter = null
+        var rowBuffer: RowBuffer = null
+        val windowFunctionResult = new SpecificMutableRow(expressions.map(_.dataType))
+        val frames = factories.map(_(windowFunctionResult))
+        val numFrames = frames.length
+        private[this] def fetchNextPartition(): Unit = {
+          // Collect all the rows in the current partition.
+          // Before we start to fetch new input rows, make a copy of nextGroup.
+          val currentGroup = nextGroup.copy()
+
+          // clear last partition
+          if (sorter != null) {
+            // the last sorter of this task will be cleaned up via task completion listener
+            sorter.cleanupResources()
+            sorter = null
+          } else {
+            rows.clear()
+          }
+
+          // Rows are buffered in memory until a partition reaches 4096 rows; from then on the
+          // partition is stored in an UnsafeExternalSorter, which is used purely as a row buffer.
+          while (nextRowAvailable && nextGroup == currentGroup) {
+            if (sorter == null) {
+              rows += nextRow.copy()
+
+              if (rows.length >= 4096) {
+                // We will not sort the rows, so prefixComparator and recordComparator are null.
+                sorter = UnsafeExternalSorter.create(
+                  TaskContext.get().taskMemoryManager(),
+                  SparkEnv.get.blockManager,
+                  SparkEnv.get.serializerManager,
+                  TaskContext.get(),
+                  null,
+                  null,
+                  1024,
+                  SparkEnv.get.memoryManager.pageSizeBytes,
+                  false)
+                rows.foreach { r =>
+                  sorter.insertRecord(r.getBaseObject, r.getBaseOffset, r.getSizeInBytes, 0)
+                }
+                rows.clear()
+              }
+            } else {
+              sorter.insertRecord(nextRow.getBaseObject, nextRow.getBaseOffset,
+                nextRow.getSizeInBytes, 0)
+            }
+            fetchNextRow()
+          }
+          if (sorter != null) {
+            rowBuffer = new ExternalRowBuffer(sorter, inputFields)
+          } else {
+            rowBuffer = new ArrayRowBuffer(rows)
+          }
+
+          // Setup the frames. Each frame gets its own cursor over the partition's rows.
+          var i = 0
+          while (i < numFrames) {
+            frames(i).prepare(rowBuffer.copy())
+            i += 1
+          }
+
+          // Setup iteration
+          rowIndex = 0
+          rowsSize = rowBuffer.size()
+        }
+
+        // Iteration
+        var rowIndex = 0
+        var rowsSize = 0L
+
+        override final def hasNext: Boolean = rowIndex < rowsSize || nextRowAvailable
+
+        val join = new JoinedRow
+        override final def next(): InternalRow = {
+          // Load the next partition if we need to.
+          if (rowIndex >= rowsSize && nextRowAvailable) {
+            fetchNextPartition()
+          }
+
+          if (rowIndex < rowsSize) {
+            // Get the results for the window frames.
+            var i = 0
+            val current = rowBuffer.next()
+            while (i < numFrames) {
+              frames(i).write(rowIndex, current)
+              i += 1
+            }
+
+            // 'Merge' the input row with the window function result
+            join(current, windowFunctionResult)
+            rowIndex += 1
+
+            // Return the projection.
+            result(join)
+          } else throw new NoSuchElementException
+        }
+      }
+    }
+  }
+}
+
+/**
+ * Compares an input row against the frame boundary derived from an output row. The sign of the
+ * result tells where the input row lies relative to that boundary: negative means before it,
+ * zero means on it and positive means past it.
+ */
+private[execution] abstract class BoundOrdering {
+  def compare(
+      inputRow: InternalRow,
+      inputIndex: Int,
+      outputRow: InternalRow,
+      outputIndex: Int): Int
+}
+
+/**
+ * Row-based boundary: only positions matter. The boundary of the output row at `outputIndex` is
+ * the position `outputIndex + offset`; the row contents themselves are ignored.
+ */
+private[execution] final case class RowBoundOrdering(offset: Int) extends BoundOrdering {
+  override def compare(
+      inputRow: InternalRow,
+      inputIndex: Int,
+      outputRow: InternalRow,
+      outputIndex: Int): Int = {
+    val boundIndex = outputIndex + offset
+    inputIndex - boundIndex
+  }
+}
+
+/**
+ * Value-based boundary: the `current` projection extracts the ordering value of the input row,
+ * the `bound` projection extracts the (possibly offset) boundary value of the output row, and
+ * `ordering` compares the two. Row positions are ignored.
+ */
+private[execution] final case class RangeBoundOrdering(
+    ordering: Ordering[InternalRow],
+    current: Projection,
+    bound: Projection) extends BoundOrdering {
+  override def compare(
+      inputRow: InternalRow,
+      inputIndex: Int,
+      outputRow: InternalRow,
+      outputIndex: Int): Int = {
+    val inputValue = current(inputRow)
+    val boundValue = bound(outputRow)
+    ordering.compare(inputValue, boundValue)
+  }
+}
+
+/**
+ * A forward-only cursor over the rows of a single window partition.
+ */
+private[execution] abstract class RowBuffer {
+
+  /** Total number of rows in the partition, independent of the cursor position. */
+  def size(): Int
+
+  /** Advances the cursor and returns the row under it, or null once the rows are exhausted. */
+  def next(): InternalRow
+
+  /** Advances the cursor past the next `n` rows without returning them. */
+  def skip(n: Int): Unit
+
+  /** Creates another cursor over the same rows. */
+  def copy(): RowBuffer
+}
+
+/**
+ * A [[RowBuffer]] backed by an in-memory ArrayBuffer. Copies share the same underlying
+ * ArrayBuffer; only the cursor position is per instance.
+ */
+private[execution] class ArrayRowBuffer(buffer: ArrayBuffer[UnsafeRow]) extends RowBuffer {
+
+  /** Position of the most recently returned row; -1 before the first call to `next()`. */
+  private[this] var cursor: Int = -1
+
+  def size(): Int = buffer.length
+
+  def next(): InternalRow = {
+    cursor += 1
+    if (cursor >= buffer.length) null else buffer(cursor)
+  }
+
+  def skip(n: Int): Unit = cursor += n
+
+  def copy(): RowBuffer = new ArrayRowBuffer(buffer)
+}
+
+/**
+ * A [[RowBuffer]] that reads the records stored in an UnsafeExternalSorter. Each instance
+ * obtains its own iterator from the sorter; copies share the sorter itself.
+ */
+private[execution] class ExternalRowBuffer(sorter: UnsafeExternalSorter, numFields: Int)
+  extends RowBuffer {
+
+  private[this] val iter: UnsafeSorterIterator = sorter.getIterator
+
+  /** Reused row that is re-pointed at each record returned by `next()`. */
+  private[this] val currentRow = new UnsafeRow(numFields)
+
+  def size(): Int = iter.getNumRecords()
+
+  def next(): InternalRow = {
+    if (!iter.hasNext) {
+      null
+    } else {
+      iter.loadNext()
+      currentRow.pointTo(iter.getBaseObject, iter.getBaseOffset, iter.getRecordLength)
+      currentRow
+    }
+  }
+
+  def skip(n: Int): Unit = {
+    var remaining = n
+    while (remaining > 0 && iter.hasNext) {
+      iter.loadNext()
+      remaining -= 1
+    }
+  }
+
+  def copy(): RowBuffer = new ExternalRowBuffer(sorter, numFields)
+}
+
+/**
+ * A window function frame calculates the results of a number of window functions for one window
+ * frame. Before use, a frame must be prepared by passing it all the rows in the current
+ * partition (see `prepare`). After preparation, `write` is called once per output row, in order
+ * of increasing index starting at 0, to fill the target row.
+ */
+private[execution] abstract class WindowFunctionFrame {
+  /**
+   * Prepare the frame for calculating the results for a partition.
+   *
+   * @param rows to calculate the frame results for.
+   */
+  def prepare(rows: RowBuffer): Unit
+
+  /**
+   * Write the current results to the target row.
+   *
+   * @param index position of the current output row within the partition.
+   * @param current the current output row.
+   */
+  def write(index: Int, current: InternalRow): Unit
+}
+
+/**
+ * The offset window frame calculates frames containing LEAD/LAG statements: for output row `i`
+ * it evaluates the functions against input row `i + offset`, or against an all-null row when
+ * that position lies outside the partition.
+ *
+ * @param target to write results to.
+ * @param ordinal position in `target` of the first result produced by this frame; the first
+ *                `ordinal` expressions of the projection are NoOp placeholders.
+ * @param expressions to shift a number of rows.
+ * @param inputSchema required for creating a projection.
+ * @param newMutableProjection function used to create the projection.
+ * @param offset by which rows get moved within a partition.
+ */
+private[execution] final class OffsetWindowFunctionFrame(
+    target: MutableRow,
+    ordinal: Int,
+    expressions: Array[Expression],
+    inputSchema: Seq[Attribute],
+    newMutableProjection: (Seq[Expression], Seq[Attribute]) => MutableProjection,
+    offset: Int) extends WindowFunctionFrame {
+
+  /** Cursor over the rows of the partition currently being processed. */
+  private[this] var input: RowBuffer = null
+
+  /** Partition position of the input row paired with the next output row. */
+  private[this] var inputIndex = 0
+
+  /** Stand-in for the input row when the shifted position falls outside the partition. */
+  private[this] val emptyRow = new GenericInternalRow(inputSchema.size)
+
+  /** Combines the shifted input row (left) with the current output row (right). */
+  private[this] val join = new JoinedRow
+
+  /** Projection that evaluates all functions of this frame into `target`. */
+  private[this] val projection = {
+    val inputAttrs = inputSchema.map(_.withNullability(true))
+    val numInputAttributes = inputAttrs.size
+
+    // Binds one function. An offset function with a non-null default becomes
+    // Coalesce(shiftedValue, default), where the default is evaluated against the right-hand
+    // side of the joined row, i.e. the current output row.
+    def bindFunction(e: Expression): Expression = e match {
+      case f: OffsetWindowFunction =>
+        val shiftedValue = BindReferences.bindReference(f.input, inputAttrs)
+        val hasNoDefault =
+          f.default == null || f.default.foldable && f.default.eval() == null
+        if (hasNoDefault) {
+          shiftedValue
+        } else {
+          val defaultValue = BindReferences.bindReference(f.default, inputAttrs).transform {
+            case BoundReference(o, dataType, nullable) =>
+              BoundReference(o + numInputAttributes, dataType, nullable)
+          }
+          org.apache.spark.sql.catalyst.expressions.Coalesce(shiftedValue :: defaultValue :: Nil)
+        }
+      case other =>
+        BindReferences.bindReference(other, inputAttrs)
+    }
+
+    val placeholders = Seq.fill(ordinal)(NoOp)
+    val boundExpressions = placeholders ++ expressions.toSeq.map(bindFunction)
+    newMutableProjection(boundExpressions, Nil).target(target)
+  }
+
+  override def prepare(rows: RowBuffer): Unit = {
+    input = rows
+    // For a positive offset, consume the first `offset` rows so that the cursor already points
+    // at the row paired with output row 0.
+    var consumed = 0
+    while (consumed < offset) {
+      input.next()
+      consumed += 1
+    }
+    inputIndex = offset
+  }
+
+  override def write(index: Int, current: InternalRow): Unit = {
+    val inRange = inputIndex >= 0 && inputIndex < input.size
+    val shiftedRow = if (inRange) input.next() else emptyRow
+    join(shiftedRow, current)
+    projection(join)
+    inputIndex += 1
+  }
+}
+
+/**
+ * The sliding window frame calculates frames with the following SQL form:
+ * ... BETWEEN 1 PRECEDING AND 1 FOLLOWING
+ *
+ * Rows entering the frame are copied into a buffer and rows leaving it are dropped from the
+ * front; the aggregates are fully recomputed from the buffer whenever its contents change.
+ *
+ * @param target to write results to.
+ * @param processor to calculate the row values with.
+ * @param lbound comparator used to identify the lower bound of an output row.
+ * @param ubound comparator used to identify the upper bound of an output row.
+ */
+private[execution] final class SlidingWindowFunctionFrame(
+    target: MutableRow,
+    processor: AggregateProcessor,
+    lbound: BoundOrdering,
+    ubound: BoundOrdering) extends WindowFunctionFrame {
+
+  /** Cursor over the rows of the partition currently being processed. */
+  private[this] var input: RowBuffer = null
+
+  /** The next row from `input` that has not yet entered the window. */
+  private[this] var nextRow: InternalRow = null
+
+  /** Copies of the rows currently inside the window. */
+  private[this] val buffer = new util.ArrayDeque[InternalRow]()
+
+  /** Index of the first input row beyond the upper bound of the current output row. */
+  private[this] var inputHighIndex = 0
+
+  /** Index of the first input row at or above the lower bound of the current output row. */
+  private[this] var inputLowIndex = 0
+
+  /** Resets all state for a new partition. */
+  override def prepare(rows: RowBuffer): Unit = {
+    input = rows
+    nextRow = rows.next()
+    inputHighIndex = 0
+    inputLowIndex = 0
+    buffer.clear()
+  }
+
+  /** Moves the window to the current output row and writes the frame results to `target`. */
+  override def write(index: Int, current: InternalRow): Unit = {
+    val grew = admitRows(index, current)
+    val shrank = evictRows(index, current)
+
+    // Recompute only when the window changed (or at the start of the partition).
+    if (index == 0 || grew || shrank) {
+      processor.initialize(input.size)
+      val rowsInWindow = buffer.iterator()
+      while (rowsInWindow.hasNext) {
+        processor.update(rowsInWindow.next())
+      }
+      processor.evaluate(target)
+    }
+  }
+
+  /** Buffers every pending input row that is at or below the output row's upper bound. */
+  private[this] def admitRows(index: Int, current: InternalRow): Boolean = {
+    var admitted = false
+    while (nextRow != null && ubound.compare(nextRow, inputHighIndex, current, index) <= 0) {
+      buffer.add(nextRow.copy())
+      nextRow = input.next()
+      inputHighIndex += 1
+      admitted = true
+    }
+    admitted
+  }
+
+  /** Drops every buffered row that lies below the output row's lower bound. */
+  private[this] def evictRows(index: Int, current: InternalRow): Boolean = {
+    var evicted = false
+    while (!buffer.isEmpty && lbound.compare(buffer.peek(), inputLowIndex, current, index) < 0) {
+      buffer.remove()
+      inputLowIndex += 1
+      evicted = true
+    }
+    evicted
+  }
+}
+
+/**
+ * The unbounded window frame calculates frames with the following SQL forms:
+ * ... (No Frame Definition)
+ * ... BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING
+ *
+ * Every row of the partition belongs to the frame, so all rows are aggregated once in `prepare`
+ * and each `write` only evaluates the aggregates.
+ *
+ * @param target to write results to.
+ * @param processor to calculate the row values with.
+ */
+private[execution] final class UnboundedWindowFunctionFrame(
+    target: MutableRow,
+    processor: AggregateProcessor) extends WindowFunctionFrame {
+
+  /** Feeds every row of the new partition into the processor up front. */
+  override def prepare(rows: RowBuffer): Unit = {
+    val total = rows.size()
+    processor.initialize(total)
+    var remaining = total
+    while (remaining > 0) {
+      processor.update(rows.next())
+      remaining -= 1
+    }
+  }
+
+  /** Writes the partition-wide results to `target`. */
+  override def write(index: Int, current: InternalRow): Unit = {
+    // Evaluation is not assumed to be deterministic, so it is repeated for every output row.
+    processor.evaluate(target)
+  }
+}
+
+/**
+ * The UnboundPreceding window frame calculates frames with the following SQL form:
+ * ... BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW
+ *
+ * Only the upper bound moves. Typical uses are running sums or counts (row_number). Unlike the
+ * sliding frame, no buffer is kept and nothing is recomputed from scratch: rows entering the frame
+ * are fed straight into the processor, and the aggregates are evaluated incrementally.
+ *
+ * @param target to write results to.
+ * @param processor to calculate the row values with.
+ * @param ubound comparator used to identify the upper bound of an output row.
+ */
+private[execution] final class UnboundedPrecedingWindowFunctionFrame(
+    target: MutableRow,
+    processor: AggregateProcessor,
+    ubound: BoundOrdering) extends WindowFunctionFrame {
+
+  /** Cursor over the rows of the partition currently being processed. */
+  private[this] var input: RowBuffer = null
+
+  /** The next row from `input` that has not yet been aggregated. */
+  private[this] var nextRow: InternalRow = null
+
+  /** Index of the first input row beyond the upper bound of the current output row. */
+  private[this] var inputIndex = 0
+
+  /** Resets the cursor and the aggregation state for a new partition. */
+  override def prepare(rows: RowBuffer): Unit = {
+    input = rows
+    nextRow = rows.next()
+    inputIndex = 0
+    processor.initialize(input.size)
+  }
+
+  /** Extends the frame up to the current row's upper bound and writes the results. */
+  override def write(index: Int, current: InternalRow): Unit = {
+    var advanced = false
+    while (nextRow != null && ubound.compare(nextRow, inputIndex, current, index) <= 0) {
+      processor.update(nextRow)
+      nextRow = input.next()
+      inputIndex += 1
+      advanced = true
+    }
+
+    // Re-evaluate only when new rows were added (or at the start of the partition).
+    if (advanced || index == 0) {
+      processor.evaluate(target)
+    }
+  }
+}
+
+/**
+ * The UnboundFollowing window frame calculates frames with the following SQL form:
+ * ... BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING
+ *
+ * There is only a lower bound. This is a slightly modified version of the sliding window. The
+ * sliding window operator has to check if both the upper and the lower bound change when a new
+ * row gets processed, whereas the unbounded following frame only has to check the lower bound.
+ *
+ * This is a very expensive operator to use, O(n * (n - 1) /2), because it must do a full
+ * recalculation over the remaining rows after each lower-bound move. Reverse iteration would be
+ * possible, if the commutativity of the used window functions can be guaranteed.
+ *
+ * @param target to write results to.
+ * @param processor to calculate the row values with.
+ * @param lbound comparator used to identify the lower bound of an output row.
+ */
+private[execution] final class UnboundedFollowingWindowFunctionFrame(
+    target: MutableRow,
+    processor: AggregateProcessor,
+    lbound: BoundOrdering) extends WindowFunctionFrame {
+
+  /** Rows of the partition currently being processed. */
+  private[this] var input: RowBuffer = null
+
+  /**
+   * Index of the first input row with a value equal to or greater than the lower bound of the
+   * current output row.
+   */
+  private[this] var inputIndex = 0
+
+  /** Prepare the frame for calculating a new partition. */
+  override def prepare(rows: RowBuffer): Unit = {
+    input = rows
+    inputIndex = 0
+  }
+
+  /** Write the frame columns for the current row to the given target row. */
+  override def write(index: Int, current: InternalRow): Unit = {
+    var bufferUpdated = index == 0
+
+    // Duplicate the input to have a new cursor positioned at the start of the partition.
+    val tmp = input.copy()
+
+    // Skip the rows that already fell below an earlier lower bound, then advance past all rows
+    // whose value is smaller than the current output row's lower bound.
+    tmp.skip(inputIndex)
+    var nextRow = tmp.next()
+    while (nextRow != null && lbound.compare(nextRow, inputIndex, current, index) < 0) {
+      nextRow = tmp.next()
+      inputIndex += 1
+      bufferUpdated = true
+    }
+
+    // Only recalculate and update when the lower bound moved. The frame is every row from
+    // `nextRow` to the end of the partition.
+    if (bufferUpdated) {
+      processor.initialize(input.size)
+      while (nextRow != null) {
+        processor.update(nextRow)
+        nextRow = tmp.next()
+      }
+      processor.evaluate(target)
+    }
+  }
+}
+
+/**
+ * This class prepares and manages the processing of a number of [[AggregateFunction]]s within a
+ * single frame. The [[WindowFunctionFrame]] takes care of processing the frame in the correct way,
+ * this reduces the processing of a [[AggregateWindowFunction]] to processing the underlying
+ * [[AggregateFunction]]. All [[AggregateFunction]]s are processed in [[Complete]] mode.
+ *
+ * [[SizeBasedWindowFunction]]s are initialized in a slightly different way. These functions
+ * require the size of the partition processed, this value is exposed to them when the processor is
+ * constructed.
+ *
+ * Processing of distinct aggregates is currently not supported.
+ *
+ * The implementation is split into an object which takes care of construction, and the actual
+ * processor class.
+ */
+private[execution] object AggregateProcessor {
+ def apply( // Builds a processor; evaluate output columns 0 until `ordinal` are NoOp placeholders.
+ functions: Array[Expression], // each must be a DeclarativeAggregate or ImperativeAggregate
+ ordinal: Int, // index of the first evaluate column that holds a function result
+ inputAttributes: Seq[Attribute], // schema of input rows; binds update exprs and imperatives
+ newMutableProjection: (Seq[Expression], Seq[Attribute]) => MutableProjection):
+ AggregateProcessor = {
+ val aggBufferAttributes = mutable.Buffer.empty[AttributeReference] // shared buffer schema
+ val initialValues = mutable.Buffer.empty[Expression] // one entry per buffer slot
+ val updateExpressions = mutable.Buffer.empty[Expression] // one entry per buffer slot
+ val evaluateExpressions = mutable.Buffer.fill[Expression](ordinal)(NoOp)
+ val imperatives = mutable.Buffer.empty[ImperativeAggregate] // driven directly by the processor
+
+ // SPARK-14244: `SizeBasedWindowFunction`s are firstly created on driver side and then
+ // serialized to executor side. These functions all reference a global singleton window
+ // partition size attribute reference, i.e., `SizeBasedWindowFunction.n`. Here we must collect
+ // the singleton instance created on driver side instead of using executor side
+ // `SizeBasedWindowFunction.n` to avoid binding failure caused by mismatching expression ID.
+ val partitionSize: Option[AttributeReference] = {
+ val aggs = functions.flatMap(_.collectFirst { case f: SizeBasedWindowFunction => f }) // per tree
+ aggs.headOption.map(_.n) // any one suffices: they all share the same singleton `n`
+ }
+
+ // Check if there are any SizeBasedWindowFunctions. If there are, we add the partition size to
+ // the aggregation buffer. Note that the ordinal of the partition size value will always be 0.
+ partitionSize.foreach { n =>
+ aggBufferAttributes += n // must be the first slot: initialize() writes the size to ordinal 0
+ initialValues += NoOp // the slot is filled by buffer.setInt in the processor's initialize
+ updateExpressions += NoOp // presumably leaves the size untouched per row; confirm NoOp semantics
+ }
+
+ // Add each AggregateFunction to the AggregateProcessor.
+ functions.foreach {
+ case agg: DeclarativeAggregate =>
+ aggBufferAttributes ++= agg.aggBufferAttributes
+ initialValues ++= agg.initialValues
+ updateExpressions ++= agg.updateExpressions
+ evaluateExpressions += agg.evaluateExpression
+ case agg: ImperativeAggregate =>
+ val offset = aggBufferAttributes.size // its slots start after everything added so far
+ val imperative = BindReferences.bindReference(agg
+ .withNewInputAggBufferOffset(offset)
+ .withNewMutableAggBufferOffset(offset),
+ inputAttributes)
+ imperatives += imperative
+ aggBufferAttributes ++= imperative.aggBufferAttributes
+ val noOps = Seq.fill(imperative.aggBufferAttributes.size)(NoOp) // it manages its own slots
+ initialValues ++= noOps
+ updateExpressions ++= noOps
+ evaluateExpressions += imperative // the bound imperative itself yields the result
+ case other =>
+ sys.error(s"Unsupported Aggregate Function: $other")
+ }
+
+ // Create the projections.
+ val initialProjection = newMutableProjection(
+ initialValues,
+ partitionSize.toSeq) // initial values may reference the partition size attribute
+ val updateProjection = newMutableProjection(
+ updateExpressions,
+ aggBufferAttributes ++ inputAttributes) // update() feeds the buffer joined with the input row
+ val evaluateProjection = newMutableProjection(
+ evaluateExpressions,
+ aggBufferAttributes) // evaluation reads only the buffer
+
+ // Create the processor
+ new AggregateProcessor(
+ aggBufferAttributes.toArray,
+ initialProjection,
+ updateProjection,
+ evaluateProjection,
+ imperatives.toArray,
+ partitionSize.isDefined) // tells the processor to store the size in buffer slot 0
+ }
+}
+
+/**
+ * This class manages the processing of a number of aggregate functions. See the documentation of
+ * the object for more information.
+ *
+ * @param bufferSchema attributes of the aggregation buffer; when `trackPartitionSize` is set,
+ * slot 0 holds the partition size.
+ * @param initialProjection projection that writes the initial values into the buffer.
+ * @param updateProjection projection applied to the buffer joined with each input row; it writes
+ * into the buffer.
+ * @param evaluateProjection projection from the buffer into the target row passed to `evaluate`.
+ * @param imperatives imperative aggregates that initialize and update the buffer directly.
+ * @param trackPartitionSize whether `initialize` stores the partition size in buffer slot 0.
+ */
+private[execution] final class AggregateProcessor(
+ private[this] val bufferSchema: Array[AttributeReference],
+ private[this] val initialProjection: MutableProjection,
+ private[this] val updateProjection: MutableProjection,
+ private[this] val evaluateProjection: MutableProjection,
+ private[this] val imperatives: Array[ImperativeAggregate],
+ private[this] val trackPartitionSize: Boolean) {
+
+ private[this] val join = new JoinedRow // reused by every update() call
+ private[this] val numImperatives = imperatives.length
+ private[this] val buffer = new SpecificMutableRow(bufferSchema.toSeq.map(_.dataType)) // typed slots
+ initialProjection.target(buffer) // both projections write into the buffer in place
+ updateProjection.target(buffer)
+
+ /** Create the initial state; `size` is written to buffer slot 0 when it is tracked. */
+ def initialize(size: Int): Unit = {
+ // Some initialization expressions are dependent on the partition size so we have to
+ // initialize the size before initializing all other fields, and we have to pass the buffer to
+ // the initialization projection.
+ if (trackPartitionSize) {
+ buffer.setInt(0, size)
+ }
+ initialProjection(buffer)
+ var i = 0
+ while (i < numImperatives) {
+ imperatives(i).initialize(buffer)
+ i += 1
+ }
+ }
+
+ /** Update the buffer with a single input row. */
+ def update(input: InternalRow): Unit = {
+ updateProjection(join(buffer, input)) // declarative aggregates first
+ var i = 0
+ while (i < numImperatives) {
+ imperatives(i).update(buffer, input) // then each imperative updates its own slots
+ i += 1
+ }
+ }
+
+ /** Evaluate the buffer, writing the results into `target`. */
+ def evaluate(target: MutableRow): Unit =
+ evaluateProjection.target(target)(buffer) // retargets the projection on every call
+}
http://git-wip-us.apache.org/repos/asf/spark/blob/d7d0cad0/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala
deleted file mode 100644
index 9fcfea8..0000000
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala
+++ /dev/null
@@ -1,111 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.execution.aggregate
-
-import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.catalyst.errors._
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.aggregate._
-import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, UnspecifiedDistribution}
-import org.apache.spark.sql.execution.{SparkPlan, UnaryNode}
-import org.apache.spark.sql.execution.metric.SQLMetrics
-
-case class SortBasedAggregate( // Sort-based aggregation operator; replaced by SortBasedAggregateExec.
- requiredChildDistributionExpressions: Option[Seq[Expression]], // None => no requirement
- groupingExpressions: Seq[NamedExpression], // input must be sorted by these
- aggregateExpressions: Seq[AggregateExpression],
- aggregateAttributes: Seq[Attribute],
- initialInputBufferOffset: Int, // passed through to SortBasedAggregationIterator
- resultExpressions: Seq[NamedExpression],
- child: SparkPlan)
- extends UnaryNode {
-
- private[this] val aggregateBufferAttributes = { // buffer attributes of every aggregate function
- aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)
- }
-
- override def producedAttributes: AttributeSet = // attributes created here, not by the child
- AttributeSet(aggregateAttributes) ++
- AttributeSet(resultExpressions.diff(groupingExpressions).map(_.toAttribute)) ++
- AttributeSet(aggregateBufferAttributes)
-
- override private[sql] lazy val metrics = Map(
- "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
-
- override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute)
-
- override def requiredChildDistribution: List[Distribution] = { // empty => AllTuples
- requiredChildDistributionExpressions match {
- case Some(exprs) if exprs.length == 0 => AllTuples :: Nil
- case Some(exprs) if exprs.length > 0 => ClusteredDistribution(exprs) :: Nil
- case None => UnspecifiedDistribution :: Nil
- }
- }
-
- override def requiredChildOrdering: Seq[Seq[SortOrder]] = { // ascending on the grouping keys
- groupingExpressions.map(SortOrder(_, Ascending)) :: Nil
- }
-
- override def outputOrdering: Seq[SortOrder] = { // output keeps the grouping-key order
- groupingExpressions.map(SortOrder(_, Ascending))
- }
-
- protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") {
- val numOutputRows = longMetric("numOutputRows")
- child.execute().mapPartitionsInternal { iter =>
- // Because the constructor of an aggregation iterator will read at least the first row,
- // we need to get the value of iter.hasNext first.
- val hasInput = iter.hasNext
- if (!hasInput && groupingExpressions.nonEmpty) {
- // This is a grouped aggregate and the input iterator is empty,
- // so return an empty iterator.
- Iterator[UnsafeRow]()
- } else {
- val outputIter = new SortBasedAggregationIterator(
- groupingExpressions,
- child.output,
- iter,
- aggregateExpressions,
- aggregateAttributes,
- initialInputBufferOffset,
- resultExpressions,
- (expressions, inputSchema) =>
- newMutableProjection(expressions, inputSchema, subexpressionEliminationEnabled),
- numOutputRows)
- if (!hasInput && groupingExpressions.isEmpty) {
- // There is no input and there is no grouping expressions.
- // We need to output a single row as the output.
- numOutputRows += 1
- Iterator[UnsafeRow](outputIter.outputForEmptyGroupingKeyWithoutInput())
- } else {
- outputIter
- }
- }
- }
- }
-
- override def simpleString: String = { // "SortBasedAggregate(key=..., functions=..., output=...)"
- val allAggregateExpressions = aggregateExpressions
-
- val keyString = groupingExpressions.mkString("[", ",", "]")
- val functionString = allAggregateExpressions.mkString("[", ",", "]")
- val outputString = output.mkString("[", ",", "]")
- s"SortBasedAggregate(key=$keyString, functions=$functionString, output=$outputString)"
- }
-}
http://git-wip-us.apache.org/repos/asf/spark/blob/d7d0cad0/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregateExec.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregateExec.scala
new file mode 100644
index 0000000..3169e0a
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregateExec.scala
@@ -0,0 +1,111 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.aggregate
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.errors._
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate._
+import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, UnspecifiedDistribution}
+import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode}
+import org.apache.spark.sql.execution.metric.SQLMetrics
+
+/**
+ * Sort-based aggregation operator. It requires its input to be sorted by the grouping
+ * expressions (see `requiredChildOrdering`) and aggregates each partition with a
+ * [[SortBasedAggregationIterator]].
+ *
+ * @param requiredChildDistributionExpressions `None` for no distribution requirement, an empty
+ * sequence for `AllTuples`, or the expressions to cluster the input by.
+ * @param groupingExpressions expressions the input is grouped and sorted by.
+ * @param aggregateExpressions aggregate expressions evaluated for each group.
+ * @param aggregateAttributes attributes produced for the aggregate results.
+ * @param initialInputBufferOffset passed through to [[SortBasedAggregationIterator]]
+ * (presumably the offset of buffer values in input rows; confirm there).
+ * @param resultExpressions expressions whose attributes form this operator's output.
+ * @param child the input plan.
+ */
+case class SortBasedAggregateExec(
+ requiredChildDistributionExpressions: Option[Seq[Expression]],
+ groupingExpressions: Seq[NamedExpression],
+ aggregateExpressions: Seq[AggregateExpression],
+ aggregateAttributes: Seq[Attribute],
+ initialInputBufferOffset: Int,
+ resultExpressions: Seq[NamedExpression],
+ child: SparkPlan)
+ extends UnaryExecNode {
+
+ private[this] val aggregateBufferAttributes = {
+ aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)
+ }
+
+ // Attributes created by this operator rather than passed through from the child.
+ override def producedAttributes: AttributeSet =
+ AttributeSet(aggregateAttributes) ++
+ AttributeSet(resultExpressions.diff(groupingExpressions).map(_.toAttribute)) ++
+ AttributeSet(aggregateBufferAttributes)
+
+ override private[sql] lazy val metrics = Map(
+ "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
+
+ override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute)
+
+ // An empty expression list maps to AllTuples; a non-empty one clusters the input by it.
+ // The second case is unguarded so the match stays visibly exhaustive.
+ override def requiredChildDistribution: List[Distribution] = {
+ requiredChildDistributionExpressions match {
+ case Some(exprs) if exprs.isEmpty => AllTuples :: Nil
+ case Some(exprs) => ClusteredDistribution(exprs) :: Nil
+ case None => UnspecifiedDistribution :: Nil
+ }
+ }
+
+ // The iterator aggregates adjacent rows, so input must be sorted by the grouping keys.
+ override def requiredChildOrdering: Seq[Seq[SortOrder]] = {
+ groupingExpressions.map(SortOrder(_, Ascending)) :: Nil
+ }
+
+ override def outputOrdering: Seq[SortOrder] = {
+ groupingExpressions.map(SortOrder(_, Ascending))
+ }
+
+ protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") {
+ val numOutputRows = longMetric("numOutputRows")
+ child.execute().mapPartitionsInternal { iter =>
+ // Because the constructor of an aggregation iterator will read at least the first row,
+ // we need to get the value of iter.hasNext first.
+ val hasInput = iter.hasNext
+ if (!hasInput && groupingExpressions.nonEmpty) {
+ // This is a grouped aggregate and the input iterator is empty,
+ // so return an empty iterator.
+ Iterator[UnsafeRow]()
+ } else {
+ val outputIter = new SortBasedAggregationIterator(
+ groupingExpressions,
+ child.output,
+ iter,
+ aggregateExpressions,
+ aggregateAttributes,
+ initialInputBufferOffset,
+ resultExpressions,
+ (expressions, inputSchema) =>
+ newMutableProjection(expressions, inputSchema, subexpressionEliminationEnabled),
+ numOutputRows)
+ if (!hasInput && groupingExpressions.isEmpty) {
+ // There is no input and there is no grouping expressions.
+ // We need to output a single row as the output.
+ numOutputRows += 1
+ Iterator[UnsafeRow](outputIter.outputForEmptyGroupingKeyWithoutInput())
+ } else {
+ outputIter
+ }
+ }
+ }
+ }
+
+ override def simpleString: String = {
+ val keyString = groupingExpressions.mkString("[", ",", "]")
+ val functionString = aggregateExpressions.mkString("[", ",", "]")
+ val outputString = output.mkString("[", ",", "]")
+ // The printed label keeps the pre-rename operator name "SortBasedAggregate".
+ s"SortBasedAggregate(key=$keyString, functions=$functionString, output=$outputString)"
+ }
+}
http://git-wip-us.apache.org/repos/asf/spark/blob/d7d0cad0/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
index 49b682a..782da0e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
@@ -38,7 +38,7 @@ case class TungstenAggregate(
initialInputBufferOffset: Int,
resultExpressions: Seq[NamedExpression],
child: SparkPlan)
- extends UnaryNode with CodegenSupport {
+ extends UnaryExecNode with CodegenSupport {
private[this] val aggregateBufferAttributes = {
aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)
http://git-wip-us.apache.org/repos/asf/spark/blob/d7d0cad0/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala
index 4682949..f93c446 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.aggregate
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.execution.SparkPlan
-import org.apache.spark.sql.execution.streaming.{StateStoreRestore, StateStoreSave}
+import org.apache.spark.sql.execution.streaming.{StateStoreRestoreExec, StateStoreSaveExec}
/**
* Utility functions used by the query planner to convert our plan to new aggregation code path.
@@ -35,7 +35,7 @@ object Utils {
val completeAggregateExpressions = aggregateExpressions.map(_.copy(mode = Complete))
val completeAggregateAttributes = completeAggregateExpressions.map(_.resultAttribute)
- SortBasedAggregate(
+ SortBasedAggregateExec(
requiredChildDistributionExpressions = Some(groupingExpressions),
groupingExpressions = groupingExpressions,
aggregateExpressions = completeAggregateExpressions,
@@ -66,7 +66,7 @@ object Utils {
resultExpressions = resultExpressions,
child = child)
} else {
- SortBasedAggregate(
+ SortBasedAggregateExec(
requiredChildDistributionExpressions = requiredChildDistributionExpressions,
groupingExpressions = groupingExpressions,
aggregateExpressions = aggregateExpressions,
@@ -295,7 +295,7 @@ object Utils {
child = partialAggregate)
}
- val restored = StateStoreRestore(groupingAttributes, None, partialMerged1)
+ val restored = StateStoreRestoreExec(groupingAttributes, None, partialMerged1)
val partialMerged2: SparkPlan = {
val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = PartialMerge))
@@ -312,7 +312,7 @@ object Utils {
child = restored)
}
- val saved = StateStoreSave(groupingAttributes, None, partialMerged2)
+ val saved = StateStoreSaveExec(groupingAttributes, None, partialMerged2)
val finalAndCompleteAggregate: SparkPlan = {
val finalAggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Final))
http://git-wip-us.apache.org/repos/asf/spark/blob/d7d0cad0/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index 892c57a..83f527f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -20,14 +20,15 @@ package org.apache.spark.sql.execution
import org.apache.spark.rdd.{PartitionwiseSampledRDD, RDD}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, ExpressionCanonicalizer, GenerateUnsafeProjection}
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, ExpressionCanonicalizer}
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution.metric.SQLMetrics
import org.apache.spark.sql.types.LongType
import org.apache.spark.util.random.{BernoulliCellSampler, PoissonSampler}
-case class Project(projectList: Seq[NamedExpression], child: SparkPlan)
- extends UnaryNode with CodegenSupport {
+/** Physical plan for Project. */
+case class ProjectExec(projectList: Seq[NamedExpression], child: SparkPlan)
+ extends UnaryExecNode with CodegenSupport {
override def output: Seq[Attribute] = projectList.map(_.toAttribute)
@@ -74,8 +75,9 @@ case class Project(projectList: Seq[NamedExpression], child: SparkPlan)
}
-case class Filter(condition: Expression, child: SparkPlan)
- extends UnaryNode with CodegenSupport with PredicateHelper {
+/** Physical plan for Filter. */
+case class FilterExec(condition: Expression, child: SparkPlan)
+ extends UnaryExecNode with CodegenSupport with PredicateHelper {
// Split out all the IsNotNulls from condition.
private val (notNullPreds, otherPreds) = splitConjunctivePredicates(condition).partition {
@@ -209,7 +211,7 @@ case class Filter(condition: Expression, child: SparkPlan)
}
/**
- * Sample the dataset.
+ * Physical plan for sampling the dataset.
*
* @param lowerBound Lower-bound of the sampling probability (usually 0.0)
* @param upperBound Upper-bound of the sampling probability. The expected fraction sampled
@@ -218,12 +220,12 @@ case class Filter(condition: Expression, child: SparkPlan)
* @param seed the random seed
* @param child the SparkPlan
*/
-case class Sample(
+case class SampleExec(
lowerBound: Double,
upperBound: Double,
withReplacement: Boolean,
seed: Long,
- child: SparkPlan) extends UnaryNode with CodegenSupport {
+ child: SparkPlan) extends UnaryExecNode with CodegenSupport {
override def output: Seq[Attribute] = child.output
private[sql] override lazy val metrics = Map(
@@ -301,13 +303,23 @@ case class Sample(
}
}
-case class Range(
+
+/**
+ * Physical plan for range (generating a range of 64-bit numbers).
+ *
+ * @param start first number in the range, inclusive.
+ * @param step size of the step increment.
+ * @param numSlices number of partitions.
+ * @param numElements total number of elements to output.
+ * @param output output attributes.
+ */
+case class RangeExec(
start: Long,
step: Long,
numSlices: Int,
numElements: BigInt,
output: Seq[Attribute])
- extends LeafNode with CodegenSupport {
+ extends LeafExecNode with CodegenSupport {
private[sql] override lazy val metrics = Map(
"numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
@@ -449,9 +461,9 @@ case class Range(
}
/**
- * Union two plans, without a distinct. This is UNION ALL in SQL.
+ * Physical plan for unioning two plans, without a distinct. This is UNION ALL in SQL.
*/
-case class Union(children: Seq[SparkPlan]) extends SparkPlan {
+case class UnionExec(children: Seq[SparkPlan]) extends SparkPlan {
override def output: Seq[Attribute] =
children.map(_.output).transpose.map(attrs =>
attrs.head.withNullability(attrs.exists(_.nullable)))
@@ -461,12 +473,12 @@ case class Union(children: Seq[SparkPlan]) extends SparkPlan {
}
/**
- * Return a new RDD that has exactly `numPartitions` partitions.
+ * Physical plan for returning a new RDD that has exactly `numPartitions` partitions.
* Similar to coalesce defined on an [[RDD]], this operation results in a narrow dependency, e.g.
* if you go from 1000 partitions to 100 partitions, there will not be a shuffle, instead each of
* the 100 new partitions will claim 10 of the current partitions.
*/
-case class Coalesce(numPartitions: Int, child: SparkPlan) extends UnaryNode {
+case class CoalesceExec(numPartitions: Int, child: SparkPlan) extends UnaryExecNode {
override def output: Seq[Attribute] = child.output
override def outputPartitioning: Partitioning = {
@@ -480,10 +492,10 @@ case class Coalesce(numPartitions: Int, child: SparkPlan) extends UnaryNode {
}
/**
- * Returns a table with the elements from left that are not in right using
+ * Physical plan for returning a table with the elements from left that are not in right using
* the built-in spark subtract function.
*/
-case class Except(left: SparkPlan, right: SparkPlan) extends BinaryNode {
+case class ExceptExec(left: SparkPlan, right: SparkPlan) extends BinaryExecNode {
override def output: Seq[Attribute] = left.output
protected override def doExecute(): RDD[InternalRow] = {
@@ -496,18 +508,18 @@ case class Except(left: SparkPlan, right: SparkPlan) extends BinaryNode {
* (hopefully structurally equivalent) tree from a different optimization sequence into an already
* resolved tree.
*/
-case class OutputFaker(output: Seq[Attribute], child: SparkPlan) extends SparkPlan {
+case class OutputFakerExec(output: Seq[Attribute], child: SparkPlan) extends SparkPlan {
def children: Seq[SparkPlan] = child :: Nil
protected override def doExecute(): RDD[InternalRow] = child.execute()
}
/**
- * A plan as subquery.
+ * Physical plan for a subquery.
*
* This is used to generate tree string for SparkScalarSubquery.
*/
-case class Subquery(name: String, child: SparkPlan) extends UnaryNode {
+case class SubqueryExec(name: String, child: SparkPlan) extends UnaryExecNode {
override def output: Seq[Attribute] = child.output
override def outputPartitioning: Partitioning = child.outputPartitioning
override def outputOrdering: Seq[SortOrder] = child.outputOrdering
http://git-wip-us.apache.org/repos/asf/spark/blob/d7d0cad0/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarTableScan.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarTableScan.scala
deleted file mode 100644
index 1f964b1..0000000
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarTableScan.scala
+++ /dev/null
@@ -1,358 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.execution.columnar
-
-import scala.collection.mutable.ArrayBuffer
-
-import org.apache.spark.{Accumulable, Accumulator, Accumulators}
-import org.apache.spark.network.util.JavaUtils
-import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
-import org.apache.spark.sql.catalyst.dsl.expressions._
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.logical
-import org.apache.spark.sql.catalyst.plans.logical.Statistics
-import org.apache.spark.sql.catalyst.plans.physical.Partitioning
-import org.apache.spark.sql.execution.{LeafNode, SparkPlan}
-import org.apache.spark.sql.execution.metric.SQLMetrics
-import org.apache.spark.sql.types.UserDefinedType
-import org.apache.spark.storage.StorageLevel
-
-private[sql] object InMemoryRelation { // Factory for InMemoryRelation.
- def apply( // Uses child.output as the output; the second parameter list takes its null defaults.
- useCompression: Boolean,
- batchSize: Int,
- storageLevel: StorageLevel,
- child: SparkPlan,
- tableName: Option[String]): InMemoryRelation =
- new InMemoryRelation(child.output, useCompression, batchSize, storageLevel, child, tableName)()
-}
-
-/**
- * CachedBatch is a cached batch of rows.
- *
- * @param numRows The total number of rows in this batch
- * @param buffers The buffers for serialized columns, one byte array per column builder
- * @param stats The statistics of the columns, flattened into a single row (the same row that is
- * added to the relation's batch statistics accumulator)
- */
-private[columnar]
-case class CachedBatch(numRows: Int, buffers: Array[Array[Byte]], stats: InternalRow)
-
-private[sql] case class InMemoryRelation( // Logical relation over columnar batches cached from child.
- output: Seq[Attribute],
- useCompression: Boolean, // passed to every ColumnBuilder
- batchSize: Int, // maximum number of rows per CachedBatch
- storageLevel: StorageLevel, // storage level of the persisted RDD
- @transient child: SparkPlan,
- tableName: Option[String])( // used to name the cached RDD
- @transient private[sql] var _cachedColumnBuffers: RDD[CachedBatch] = null, // built if null
- @transient private[sql] var _statistics: Statistics = null, // computed lazily from batchStats
- private[sql] var _batchStats: Accumulable[ArrayBuffer[InternalRow], InternalRow] = null)
- extends logical.LeafNode with MultiInstanceRelation {
-
- override def producedAttributes: AttributeSet = outputSet
-
- private val batchStats: Accumulable[ArrayBuffer[InternalRow], InternalRow] = // reused if given
- if (_batchStats == null) {
- child.sqlContext.sparkContext.accumulableCollection(ArrayBuffer.empty[InternalRow])
- } else {
- _batchStats
- }
-
- @transient val partitionStatistics = new PartitionStatistics(output) // schema of stats rows (?)
-
- private def computeSizeInBytes = { // sums the per-batch sizes recorded in batchStats
- val sizeOfRow: Expression =
- BindReferences.bindReference(
- output.map(a => partitionStatistics.forAttribute(a).sizeInBytes).reduce(Add),
- partitionStatistics.schema)
-
- batchStats.value.map(row => sizeOfRow.eval(row).asInstanceOf[Long]).sum
- }
-
- // Statistics propagation contracts:
- // 1. Non-null `_statistics` must reflect the actual statistics of the underlying data
- // 2. Only propagate statistics when `_statistics` is non-null
- private def statisticsToBePropagated = if (_statistics == null) {
- val updatedStats = statistics
- if (_statistics == null) null else updatedStats
- } else {
- _statistics
- }
-
- override def statistics: Statistics = {
- if (_statistics == null) {
- if (batchStats.value.isEmpty) {
- // Underlying columnar RDD hasn't been materialized, no useful statistics information
- // available, return the default statistics.
- Statistics(sizeInBytes = child.sqlContext.conf.defaultSizeInBytes)
- } else {
- // Underlying columnar RDD has been materialized, required information has also been
- // collected via the `batchStats` accumulator, compute the final statistics,
- // and update `_statistics`.
- _statistics = Statistics(sizeInBytes = computeSizeInBytes)
- _statistics
- }
- } else {
- // Pre-computed statistics
- _statistics
- }
- }
-
- // If the cached column buffers were not passed in, build them in the constructor.
- // The caching itself is lazy: buildBuffers only defines and persists the RDD, running no action.
- if (_cachedColumnBuffers == null) {
- buildBuffers()
- }
-
- def recache(): Unit = { // NOTE(review): NPEs if called after uncache(), which nulls the buffers
- _cachedColumnBuffers.unpersist()
- _cachedColumnBuffers = null
- buildBuffers()
- }
-
- private def buildBuffers(): Unit = { // defines and persists the RDD of CachedBatches
- val output = child.output
- val cached = child.execute().mapPartitionsInternal { rowIterator =>
- new Iterator[CachedBatch] {
- def next(): CachedBatch = { // up to batchSize rows, cut early at MAX_BATCH_SIZE_IN_BYTE
- val columnBuilders = output.map { attribute =>
- ColumnBuilder(attribute.dataType, batchSize, attribute.name, useCompression)
- }.toArray
-
- var rowCount = 0
- var totalSize = 0L
- while (rowIterator.hasNext && rowCount < batchSize
- && totalSize < ColumnBuilder.MAX_BATCH_SIZE_IN_BYTE) {
- val row = rowIterator.next()
-
- // Added for SPARK-6082. This assertion can be useful for scenarios when something
- // like Hive TRANSFORM is used. The external data generation script used in TRANSFORM
- // may result malformed rows, causing ArrayIndexOutOfBoundsException, which is somewhat
- // hard to decipher.
- assert(
- row.numFields == columnBuilders.length,
- s"Row column number mismatch, expected ${output.size} columns, " +
- s"but got ${row.numFields}." +
- s"\nRow content: $row")
-
- var i = 0
- totalSize = 0 // recomputed as the sum of every builder's current size
- while (i < row.numFields) {
- columnBuilders(i).appendFrom(row, i)
- totalSize += columnBuilders(i).columnStats.sizeInBytes
- i += 1
- }
- rowCount += 1
- }
-
- val stats = InternalRow.fromSeq(columnBuilders.map(_.columnStats.collectedStatistics)
- .flatMap(_.values))
-
- batchStats += stats // later read by `statistics` via computeSizeInBytes
- CachedBatch(rowCount, columnBuilders.map { builder =>
- JavaUtils.bufferToArray(builder.build())
- }, stats)
- }
-
- def hasNext: Boolean = rowIterator.hasNext
- }
- }.persist(storageLevel)
-
- cached.setName(tableName.map(n => s"In-memory table $n").getOrElse(child.toString))
- _cachedColumnBuffers = cached
- }
-
- def withOutput(newOutput: Seq[Attribute]): InMemoryRelation = { // same cached data, new output
- InMemoryRelation(
- newOutput, useCompression, batchSize, storageLevel, child, tableName)(
- _cachedColumnBuffers, statisticsToBePropagated, batchStats)
- }
-
- override def newInstance(): this.type = { // fresh attribute IDs; cached data and stats shared
- new InMemoryRelation(
- output.map(_.newInstance()),
- useCompression,
- batchSize,
- storageLevel,
- child,
- tableName)(
- _cachedColumnBuffers,
- statisticsToBePropagated,
- batchStats).asInstanceOf[this.type]
- }
-
- def cachedColumnBuffers: RDD[CachedBatch] = _cachedColumnBuffers
-
- override protected def otherCopyArgs: Seq[AnyRef] = // carries the second parameter list
- Seq(_cachedColumnBuffers, statisticsToBePropagated, batchStats)
-
- private[sql] def uncache(blocking: Boolean): Unit = { // NOTE(review): a second call would NPE
- Accumulators.remove(batchStats.id)
- cachedColumnBuffers.unpersist(blocking)
- _cachedColumnBuffers = null
- }
-}
-
- /**
- * Physical operator that scans the cached batches of an InMemoryRelation, decoding only the
- * requested `attributes`. When in-memory partition pruning is enabled, batches whose
- * statistics show that the `predicates` cannot be satisfied are skipped.
- */
- private[sql] case class InMemoryColumnarTableScan(
- attributes: Seq[Attribute],
- predicates: Seq[Expression],
- @transient relation: InMemoryRelation)
- extends LeafNode {
-
- private[sql] override lazy val metrics = Map(
- "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
-
- override def output: Seq[Attribute] = attributes
-
- // The cached version does not change the outputPartitioning of the original SparkPlan.
- override def outputPartitioning: Partitioning = relation.child.outputPartitioning
-
- // The cached version does not change the outputOrdering of the original SparkPlan.
- override def outputOrdering: Seq[SortOrder] = relation.child.outputOrdering
-
- // Statistics attributes (bounds, counts, sizes) recorded for one column of the relation.
- private def statsFor(a: Attribute) = relation.partitionStatistics.forAttribute(a)
-
- // Returned filter predicate should return false iff it is impossible for the input expression
- // to evaluate to `true` based on statistics collected about this partition batch.
- // Predicates that cannot be translated are simply not in the domain of this partial function.
- @transient val buildFilter: PartialFunction[Expression, Expression] = {
- case And(lhs: Expression, rhs: Expression)
- if buildFilter.isDefinedAt(lhs) || buildFilter.isDefinedAt(rhs) =>
- (buildFilter.lift(lhs) ++ buildFilter.lift(rhs)).reduce(_ && _)
-
- case Or(lhs: Expression, rhs: Expression)
- if buildFilter.isDefinedAt(lhs) && buildFilter.isDefinedAt(rhs) =>
- buildFilter(lhs) || buildFilter(rhs)
-
- case EqualTo(a: AttributeReference, l: Literal) =>
- statsFor(a).lowerBound <= l && l <= statsFor(a).upperBound
- case EqualTo(l: Literal, a: AttributeReference) =>
- statsFor(a).lowerBound <= l && l <= statsFor(a).upperBound
-
- case LessThan(a: AttributeReference, l: Literal) => statsFor(a).lowerBound < l
- case LessThan(l: Literal, a: AttributeReference) => l < statsFor(a).upperBound
-
- case LessThanOrEqual(a: AttributeReference, l: Literal) => statsFor(a).lowerBound <= l
- case LessThanOrEqual(l: Literal, a: AttributeReference) => l <= statsFor(a).upperBound
-
- case GreaterThan(a: AttributeReference, l: Literal) => l < statsFor(a).upperBound
- case GreaterThan(l: Literal, a: AttributeReference) => statsFor(a).lowerBound < l
-
- case GreaterThanOrEqual(a: AttributeReference, l: Literal) => l <= statsFor(a).upperBound
- case GreaterThanOrEqual(l: Literal, a: AttributeReference) => statsFor(a).lowerBound <= l
-
- case IsNull(a: Attribute) => statsFor(a).nullCount > 0
- case IsNotNull(a: Attribute) => statsFor(a).count - statsFor(a).nullCount > 0
- }
-
- // Statistics filters derived from `predicates`, bound to the partition-statistics schema;
- // filters that fail to resolve against that schema are dropped.
- val partitionFilters: Seq[Expression] = {
- predicates.flatMap { p =>
- val filter = buildFilter.lift(p)
- val boundFilter =
- filter.map(
- BindReferences.bindReference(
- _,
- relation.partitionStatistics.schema,
- allowFailures = true))
-
- boundFilter.foreach(_ =>
- filter.foreach(f => logInfo(s"Predicate $p generates partition filter: $f")))
-
- // If the filter can't be resolved then we are missing required statistics.
- boundFilter.filter(_.resolved)
- }
- }
-
- // Test-only switch read from "spark.sql.inMemoryTableScanStatistics.enable" (default "false").
- lazy val enableAccumulators: Boolean =
- sqlContext.getConf("spark.sql.inMemoryTableScanStatistics.enable", "false").toBoolean
-
- // Accumulators used for testing purposes
- lazy val readPartitions: Accumulator[Int] = sparkContext.accumulator(0)
- lazy val readBatches: Accumulator[Int] = sparkContext.accumulator(0)
-
- private val inMemoryPartitionPruningEnabled = sqlContext.conf.inMemoryPartitionPruning
-
- // Reads the cached buffers, optionally prunes batches, and decodes the requested columns.
- protected override def doExecute(): RDD[InternalRow] = {
- val numOutputRows = longMetric("numOutputRows")
-
- if (enableAccumulators) {
- readPartitions.setValue(0)
- readBatches.setValue(0)
- }
-
- // Using these variables here to avoid serialization of entire objects (if referenced directly)
- // within the mapPartitions closure.
- val schema = relation.partitionStatistics.schema
- val schemaIndex = schema.zipWithIndex
- val relOutput = relation.output
- val buffers = relation.cachedColumnBuffers
-
- buffers.mapPartitionsInternal { cachedBatchIterator =>
- val partitionFilter = newPredicate(
- partitionFilters.reduceOption(And).getOrElse(Literal(true)),
- schema)
-
- // Find the ordinals and data types of the requested columns.
- val (requestedColumnIndices, requestedColumnDataTypes) =
- attributes.map { a =>
- relOutput.indexWhere(_.exprId == a.exprId) -> a.dataType
- }.unzip
-
- // Do partition batch pruning if enabled
- val cachedBatchesToScan =
- if (inMemoryPartitionPruningEnabled) {
- cachedBatchIterator.filter { cachedBatch =>
- if (!partitionFilter(cachedBatch.stats)) {
- def statsString: String = schemaIndex.map {
- case (a, i) =>
- val value = cachedBatch.stats.get(i, a.dataType)
- s"${a.name}: $value"
- }.mkString(", ")
- logInfo(s"Skipping partition based on stats $statsString")
- false
- } else {
- if (enableAccumulators) {
- readBatches += 1
- }
- true
- }
- }
- } else {
- cachedBatchIterator
- }
-
- // update SQL metrics
- val withMetrics = cachedBatchesToScan.map { batch =>
- numOutputRows += batch.numRows
- batch
- }
-
- // UDT columns are decoded using their underlying sqlType.
- val columnTypes = requestedColumnDataTypes.map {
- case udt: UserDefinedType[_] => udt.sqlType
- case other => other
- }.toArray
- val columnarIterator = GenerateColumnAccessor.generate(columnTypes)
- columnarIterator.initialize(withMetrics, columnTypes, requestedColumnIndices.toArray)
- if (enableAccumulators && columnarIterator.hasNext) {
- readPartitions += 1
- }
- columnarIterator
- }
- }
-}
http://git-wip-us.apache.org/repos/asf/spark/blob/d7d0cad0/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
new file mode 100644
index 0000000..cb957b9
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
@@ -0,0 +1,358 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.columnar
+
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.spark.{Accumulable, Accumulator, Accumulators}
+import org.apache.spark.network.util.JavaUtils
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.logical
+import org.apache.spark.sql.catalyst.plans.logical.Statistics
+import org.apache.spark.sql.catalyst.plans.physical.Partitioning
+import org.apache.spark.sql.execution.{LeafExecNode, SparkPlan}
+import org.apache.spark.sql.execution.metric.SQLMetrics
+import org.apache.spark.sql.types.UserDefinedType
+import org.apache.spark.storage.StorageLevel
+
+private[sql] object InMemoryRelation {
+
+  /**
+   * Creates an [[InMemoryRelation]] that exposes the output attributes of `child`.
+   *
+   * The second parameter list is left at its defaults, so the relation's constructor creates a
+   * fresh batch-statistics accumulator and defines the cached buffers itself.
+   */
+  def apply(
+      useCompression: Boolean,
+      batchSize: Int,
+      storageLevel: StorageLevel,
+      child: SparkPlan,
+      tableName: Option[String]): InMemoryRelation = {
+    val relation = new InMemoryRelation(
+      output = child.output,
+      useCompression = useCompression,
+      batchSize = batchSize,
+      storageLevel = storageLevel,
+      child = child,
+      tableName = tableName)()
+    relation
+  }
+}
+
+/**
+ * One batch of rows cached in columnar form.
+ *
+ * @param numRows number of rows stored in this batch
+ * @param buffers serialized column data, one byte array per column builder
+ * @param stats the statistics collected by every column builder while the batch was built,
+ *              flattened into a single row
+ */
+private[columnar] case class CachedBatch(
+    numRows: Int,
+    buffers: Array[Array[Byte]],
+    stats: InternalRow)
+
+/**
+ * Logical plan node holding the output of `child` cached in memory as columnar
+ * [[CachedBatch]]es.
+ *
+ * Per-batch column statistics are collected through an accumulator while the cached RDD is
+ * materialized, and are used to estimate the size of the relation.
+ *
+ * @param output attributes exposed by this relation
+ * @param useCompression forwarded to every `ColumnBuilder` (presumably toggles compression)
+ * @param batchSize maximum number of rows per cached batch
+ * @param storageLevel storage level used to persist the cached RDD
+ * @param child physical plan whose output is cached
+ * @param tableName optional table name, used when naming the cached RDD
+ */
+private[sql] case class InMemoryRelation(
+    output: Seq[Attribute],
+    useCompression: Boolean,
+    batchSize: Int,
+    storageLevel: StorageLevel,
+    @transient child: SparkPlan,
+    tableName: Option[String])(
+    @transient private[sql] var _cachedColumnBuffers: RDD[CachedBatch] = null,
+    @transient private[sql] var _statistics: Statistics = null,
+    private[sql] var _batchStats: Accumulable[ArrayBuffer[InternalRow], InternalRow] = null)
+  extends logical.LeafNode with MultiInstanceRelation {
+
+  override def producedAttributes: AttributeSet = outputSet
+
+  // Copies made by `withOutput`/`newInstance` pass their accumulator in, so they share the
+  // collected batch statistics. A brand-new relation creates its own accumulator.
+  private val batchStats: Accumulable[ArrayBuffer[InternalRow], InternalRow] =
+    if (_batchStats == null) {
+      child.sqlContext.sparkContext.accumulableCollection(ArrayBuffer.empty[InternalRow])
+    } else {
+      _batchStats
+    }
+
+  @transient val partitionStatistics = new PartitionStatistics(output)
+
+  /**
+   * Total size in bytes of the batches whose statistics have been collected so far.
+   * It is obtained by adding up the per-column `sizeInBytes` statistic of every batch.
+   */
+  private def computeSizeInBytes = {
+    // `reduce` throws on an empty `output` (a cached zero-column plan). Such a relation
+    // contributes zero bytes per batch instead of failing when statistics are requested.
+    val sizeOfRow: Expression =
+      BindReferences.bindReference(
+        output.map(a => partitionStatistics.forAttribute(a).sizeInBytes)
+          .reduceOption[Expression](Add)
+          .getOrElse(Literal(0L)),
+        partitionStatistics.schema)
+
+    batchStats.value.map(row => sizeOfRow.eval(row).asInstanceOf[Long]).sum
+  }
+
+  // Statistics propagation contracts:
+  // 1. Non-null `_statistics` must reflect the actual statistics of the underlying data
+  // 2. Only propagate statistics when `_statistics` is non-null
+  private def statisticsToBePropagated = if (_statistics == null) {
+    val updatedStats = statistics
+    if (_statistics == null) null else updatedStats
+  } else {
+    _statistics
+  }
+
+  override def statistics: Statistics = {
+    if (_statistics == null) {
+      if (batchStats.value.isEmpty) {
+        // Underlying columnar RDD hasn't been materialized, no useful statistics information
+        // available, return the default statistics.
+        Statistics(sizeInBytes = child.sqlContext.conf.defaultSizeInBytes)
+      } else {
+        // Underlying columnar RDD has been materialized, required information has also been
+        // collected via the `batchStats` accumulator, compute the final statistics,
+        // and update `_statistics`.
+        _statistics = Statistics(sizeInBytes = computeSizeInBytes)
+        _statistics
+      }
+    } else {
+      // Pre-computed statistics
+      _statistics
+    }
+  }
+
+  // If the cached column buffers were not passed in, we calculate them in the constructor.
+  // As in Spark, the actual work of caching is lazy.
+  if (_cachedColumnBuffers == null) {
+    buildBuffers()
+  }
+
+  /** Unpersists the current cached buffers (if there are any) and defines new ones. */
+  def recache(): Unit = {
+    // The buffers are null after `uncache`; guard so recaching does not throw an NPE.
+    if (_cachedColumnBuffers != null) {
+      _cachedColumnBuffers.unpersist()
+      _cachedColumnBuffers = null
+    }
+    buildBuffers()
+  }
+
+  /**
+   * Defines the persisted RDD of cached batches built from `child`.
+   *
+   * Each partition of `child` is cut into batches of at most `batchSize` rows. No further row
+   * is added to a batch once its size reaches `ColumnBuilder.MAX_BATCH_SIZE_IN_BYTE`.
+   */
+  private def buildBuffers(): Unit = {
+    // Named to avoid shadowing this relation's own `output`.
+    val childOutput = child.output
+    val cached = child.execute().mapPartitionsInternal { rowIterator =>
+      new Iterator[CachedBatch] {
+        def next(): CachedBatch = {
+          val columnBuilders = childOutput.map { attribute =>
+            ColumnBuilder(attribute.dataType, batchSize, attribute.name, useCompression)
+          }.toArray
+
+          var rowCount = 0
+          var totalSize = 0L
+          while (rowIterator.hasNext && rowCount < batchSize
+              && totalSize < ColumnBuilder.MAX_BATCH_SIZE_IN_BYTE) {
+            val row = rowIterator.next()
+
+            // Added for SPARK-6082. This assertion can be useful for scenarios when something
+            // like Hive TRANSFORM is used. The external data generation script used in TRANSFORM
+            // may result malformed rows, causing ArrayIndexOutOfBoundsException, which is somewhat
+            // hard to decipher.
+            assert(
+              row.numFields == columnBuilders.length,
+              s"Row column number mismatch, expected ${childOutput.size} columns, " +
+                s"but got ${row.numFields}." +
+                s"\nRow content: $row")
+
+            var i = 0
+            // `columnStats.sizeInBytes` appears to be a running per-column total, so the batch
+            // total is re-summed from scratch after every row.
+            totalSize = 0
+            while (i < row.numFields) {
+              columnBuilders(i).appendFrom(row, i)
+              totalSize += columnBuilders(i).columnStats.sizeInBytes
+              i += 1
+            }
+            rowCount += 1
+          }
+
+          // Flatten every column's collected statistics into one stats row for this batch.
+          val stats = InternalRow.fromSeq(columnBuilders.map(_.columnStats.collectedStatistics)
+            .flatMap(_.values))
+
+          batchStats += stats
+          CachedBatch(rowCount, columnBuilders.map { builder =>
+            JavaUtils.bufferToArray(builder.build())
+          }, stats)
+        }
+
+        def hasNext: Boolean = rowIterator.hasNext
+      }
+    }.persist(storageLevel)
+
+    cached.setName(tableName.map(n => s"In-memory table $n").getOrElse(child.toString))
+    _cachedColumnBuffers = cached
+  }
+
+  /**
+   * Returns a copy exposing `newOutput`.
+   * The copy reuses the cached buffers, the batch-statistics accumulator and any statistics
+   * already computed.
+   */
+  def withOutput(newOutput: Seq[Attribute]): InMemoryRelation = {
+    InMemoryRelation(
+      newOutput, useCompression, batchSize, storageLevel, child, tableName)(
+      _cachedColumnBuffers, statisticsToBePropagated, batchStats)
+  }
+
+  /** Same as this relation, but with fresh instances of the output attributes. */
+  override def newInstance(): this.type = {
+    new InMemoryRelation(
+      output.map(_.newInstance()),
+      useCompression,
+      batchSize,
+      storageLevel,
+      child,
+      tableName)(
+      _cachedColumnBuffers,
+      statisticsToBePropagated,
+      batchStats).asInstanceOf[this.type]
+  }
+
+  def cachedColumnBuffers: RDD[CachedBatch] = _cachedColumnBuffers
+
+  // The second constructor parameter list must be carried over when the tree node is copied.
+  override protected def otherCopyArgs: Seq[AnyRef] =
+    Seq(_cachedColumnBuffers, statisticsToBePropagated, batchStats)
+
+  /**
+   * Unregisters the batch-statistics accumulator and unpersists the cached buffers.
+   * If the buffers were already dropped, the unpersist step is skipped instead of throwing.
+   */
+  private[sql] def uncache(blocking: Boolean): Unit = {
+    Accumulators.remove(batchStats.id)
+    if (_cachedColumnBuffers != null) {
+      _cachedColumnBuffers.unpersist(blocking)
+      _cachedColumnBuffers = null
+    }
+  }
+}
+
+/**
+ * Physical scan over the cached batches of an [[InMemoryRelation]].
+ *
+ * Only the columns listed in `attributes` are decoded. When in-memory partition pruning is
+ * enabled, `predicates` are translated into filters over per-batch column statistics, and
+ * batches that cannot contain matching rows are skipped.
+ */
+private[sql] case class InMemoryTableScanExec(
+    attributes: Seq[Attribute],
+    predicates: Seq[Expression],
+    @transient relation: InMemoryRelation)
+  extends LeafExecNode {
+
+  private[sql] override lazy val metrics = Map(
+    "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
+
+  override def output: Seq[Attribute] = attributes
+
+  // Caching changes neither the partitioning nor the ordering of the cached plan, so both are
+  // inherited from the relation's child.
+  override def outputPartitioning: Partitioning = relation.child.outputPartitioning
+
+  override def outputOrdering: Seq[SortOrder] = relation.child.outputOrdering
+
+  private def statsFor(a: Attribute) = relation.partitionStatistics.forAttribute(a)
+
+  /**
+   * Translates a predicate into an expression over one batch's column statistics.
+   *
+   * The translated expression evaluates to false only when the statistics prove that the
+   * original predicate cannot be `true` for any row of the batch. Predicates that cannot be
+   * translated are outside the domain of this partial function.
+   */
+  @transient val buildFilter: PartialFunction[Expression, Expression] = {
+    // A conjunction keeps whichever sides are translatable; dropping a side only weakens it.
+    case And(left: Expression, right: Expression)
+        if buildFilter.isDefinedAt(left) || buildFilter.isDefinedAt(right) =>
+      Seq(left, right).flatMap(side => buildFilter.lift(side).toSeq).reduce(_ && _)
+
+    // A disjunction can only be translated when both sides are.
+    case Or(left: Expression, right: Expression)
+        if buildFilter.isDefinedAt(left) && buildFilter.isDefinedAt(right) =>
+      buildFilter(left) || buildFilter(right)
+
+    // attribute <op> literal
+    case EqualTo(a: AttributeReference, l: Literal) =>
+      statsFor(a).lowerBound <= l && l <= statsFor(a).upperBound
+    case LessThan(a: AttributeReference, l: Literal) => statsFor(a).lowerBound < l
+    case LessThanOrEqual(a: AttributeReference, l: Literal) => statsFor(a).lowerBound <= l
+    case GreaterThan(a: AttributeReference, l: Literal) => l < statsFor(a).upperBound
+    case GreaterThanOrEqual(a: AttributeReference, l: Literal) => l <= statsFor(a).upperBound
+
+    // literal <op> attribute
+    case EqualTo(l: Literal, a: AttributeReference) =>
+      statsFor(a).lowerBound <= l && l <= statsFor(a).upperBound
+    case LessThan(l: Literal, a: AttributeReference) => l < statsFor(a).upperBound
+    case LessThanOrEqual(l: Literal, a: AttributeReference) => l <= statsFor(a).upperBound
+    case GreaterThan(l: Literal, a: AttributeReference) => statsFor(a).lowerBound < l
+    case GreaterThanOrEqual(l: Literal, a: AttributeReference) => statsFor(a).lowerBound <= l
+
+    // null checks
+    case IsNull(a: Attribute) => statsFor(a).nullCount > 0
+    case IsNotNull(a: Attribute) => statsFor(a).count - statsFor(a).nullCount > 0
+  }
+
+  /** Statistics filters for `predicates`, bound to the partition-statistics schema. */
+  val partitionFilters: Seq[Expression] = predicates.flatMap { predicate =>
+    buildFilter.lift(predicate).flatMap { candidate =>
+      val bound = BindReferences.bindReference(
+        candidate, relation.partitionStatistics.schema, allowFailures = true)
+      logInfo(s"Predicate $predicate generates partition filter: $candidate")
+      // An unresolved result means required statistics are missing, so the filter is dropped.
+      Some(bound).filter(_.resolved)
+    }
+  }
+
+  // Test-only instrumentation, enabled through
+  // "spark.sql.inMemoryTableScanStatistics.enable" (off by default).
+  lazy val enableAccumulators: Boolean =
+    sqlContext.getConf("spark.sql.inMemoryTableScanStatistics.enable", "false").toBoolean
+
+  // Accumulators used for testing purposes
+  lazy val readPartitions: Accumulator[Int] = sparkContext.accumulator(0)
+  lazy val readBatches: Accumulator[Int] = sparkContext.accumulator(0)
+
+  private val inMemoryPartitionPruningEnabled = sqlContext.conf.inMemoryPartitionPruning
+
+  protected override def doExecute(): RDD[InternalRow] = {
+    val numOutputRows = longMetric("numOutputRows")
+
+    if (enableAccumulators) {
+      readPartitions.setValue(0)
+      readBatches.setValue(0)
+    }
+
+    // Take everything the partition closure needs off the (transient) relation up front, so
+    // the closure never has to reach through `relation`.
+    val statsSchema = relation.partitionStatistics.schema
+    val indexedStatsSchema = statsSchema.zipWithIndex
+    val relationOutput = relation.output
+    val cachedBuffers = relation.cachedColumnBuffers
+
+    cachedBuffers.mapPartitionsInternal { batchIter =>
+      val batchMayMatch = newPredicate(
+        partitionFilters.reduceOption(And).getOrElse(Literal(true)),
+        statsSchema)
+
+      // Ordinal within the cached relation, and data type, of every requested column.
+      val (columnOrdinals, requestedTypes) = attributes.map { attr =>
+        relationOutput.indexWhere(_.exprId == attr.exprId) -> attr.dataType
+      }.unzip
+
+      val batchesToScan =
+        if (!inMemoryPartitionPruningEnabled) {
+          batchIter
+        } else {
+          batchIter.filter { batch =>
+            val keep = batchMayMatch(batch.stats)
+            if (keep) {
+              if (enableAccumulators) {
+                readBatches += 1
+              }
+            } else {
+              def describeStats: String = indexedStatsSchema.map { case (field, ordinal) =>
+                s"${field.name}: ${batch.stats.get(ordinal, field.dataType)}"
+              }.mkString(", ")
+              logInfo(s"Skipping partition based on stats $describeStats")
+            }
+            keep
+          }
+        }
+
+      // Every batch handed to the column accessor adds its rows to the output metric.
+      val countedBatches = batchesToScan.map { batch =>
+        numOutputRows += batch.numRows
+        batch
+      }
+
+      // Column accessors work on a UDT's underlying `sqlType`.
+      val columnTypes = requestedTypes.map {
+        case udt: UserDefinedType[_] => udt.sqlType
+        case other => other
+      }.toArray
+      val columnarIterator = GenerateColumnAccessor.generate(columnTypes)
+      columnarIterator.initialize(countedBatches, columnTypes, columnOrdinals.toArray)
+      if (enableAccumulators && columnarIterator.hasNext) {
+        readPartitions += 1
+      }
+      columnarIterator
+    }
+  }
+}
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org