Posted to commits@spark.apache.org by we...@apache.org on 2018/01/16 14:41:49 UTC
spark git commit: [SPARK-22392][SQL] data source v2 columnar batch reader
Repository: spark
Updated Branches:
refs/heads/master b85eb946a -> 75db14864
[SPARK-22392][SQL] data source v2 columnar batch reader
## What changes were proposed in this pull request?
This PR adds a new Data Source V2 mix-in interface, `SupportsScanColumnarBatch`, which allows a data source to return `ColumnarBatch` data during the scan.
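As a hedged sketch of the implementation side, condensed from the new `BatchDataSourceV2` Scala test source added in this patch (the `BatchReadTask` it references is added there as well):

```scala
import java.util.{Arrays => JArrays, List => JList}

import org.apache.spark.sql.sources.v2.{DataSourceV2, DataSourceV2Options, ReadSupport}
import org.apache.spark.sql.sources.v2.reader.{DataSourceV2Reader, ReadTask, SupportsScanColumnarBatch}
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.vectorized.ColumnarBatch

class BatchDataSourceV2 extends DataSourceV2 with ReadSupport {

  class Reader extends DataSourceV2Reader with SupportsScanColumnarBatch {
    override def readSchema(): StructType = new StructType().add("i", "int").add("j", "int")

    // A columnar reader returns tasks that produce ColumnarBatch instead of rows,
    // so Spark can scan vectorized data directly.
    override def createBatchReadTasks(): JList[ReadTask[ColumnarBatch]] =
      JArrays.asList(new BatchReadTask(0, 50), new BatchReadTask(50, 90))
  }

  override def createReader(options: DataSourceV2Options): DataSourceV2Reader = new Reader
}
```

The row-based `createReadTasks()` is never called on such a reader: `SupportsScanColumnarBatch` provides a default implementation that throws.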
## How was this patch tested?
new tests
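The new tests read from the batch sources through the normal DataFrame API; a condensed, hedged usage sketch (`JavaBatchDataSourceV2` is the Java test source added in this patch, in test scope; everything else is standard DataFrame API):

```scala
import org.apache.spark.sql.SparkSession
import test.org.apache.spark.sql.sources.v2.JavaBatchDataSourceV2

val spark = SparkSession.builder().master("local[2]").appName("batch-scan-demo").getOrCreate()
import spark.implicits._

// JavaBatchDataSourceV2 is the columnar batch test source added by this patch.
val df = spark.read
  .format(classOf[JavaBatchDataSourceV2].getName)
  .load()

// The query API is unchanged; the physical scan now reads ColumnarBatch.
df.filter('i > 50).show()
```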
Author: Wenchen Fan <we...@databricks.com>
Closes #20153 from cloud-fan/columnar-reader.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/75db1486
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/75db1486
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/75db1486
Branch: refs/heads/master
Commit: 75db14864d2bd9b8e13154226e94d466e3a7e0a0
Parents: b85eb94
Author: Wenchen Fan <we...@databricks.com>
Authored: Tue Jan 16 22:41:30 2018 +0800
Committer: Wenchen Fan <we...@databricks.com>
Committed: Tue Jan 16 22:41:30 2018 +0800
----------------------------------------------------------------------
.../sources/v2/reader/DataSourceV2Reader.java | 5 +-
.../v2/reader/SupportsScanColumnarBatch.java | 52 +++++++++
.../v2/reader/SupportsScanUnsafeRow.java | 2 +-
.../spark/sql/execution/ColumnarBatchScan.scala | 37 +++++-
.../sql/execution/DataSourceScanExec.scala | 39 ++-----
.../columnar/InMemoryTableScanExec.scala | 101 +++++++++--------
.../datasources/v2/DataSourceRDD.scala | 20 ++--
.../datasources/v2/DataSourceV2ScanExec.scala | 72 +++++++-----
.../ContinuousDataSourceRDDIter.scala | 4 +-
.../sql/sources/v2/JavaBatchDataSourceV2.java | 112 +++++++++++++++++++
.../sql/execution/WholeStageCodegenSuite.scala | 28 ++---
.../sql/sources/v2/DataSourceV2Suite.scala | 72 +++++++++++-
12 files changed, 400 insertions(+), 144 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/75db1486/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceV2Reader.java
----------------------------------------------------------------------
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceV2Reader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceV2Reader.java
index 95ee4a8..f23c384 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceV2Reader.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceV2Reader.java
@@ -38,7 +38,10 @@ import org.apache.spark.sql.types.StructType;
* 2. Information Reporting. E.g., statistics reporting, ordering reporting, etc.
* Names of these interfaces start with `SupportsReporting`.
* 3. Special scans. E.g, columnar scan, unsafe row scan, etc.
- * Names of these interfaces start with `SupportsScan`.
+ * Names of these interfaces start with `SupportsScan`. Note that a reader should implement at
+ * most one of the special scans. If more than one is implemented, only one of them will be
+ * respected, according to the following priority list, from high to low:
+ * {@link SupportsScanColumnarBatch}, {@link SupportsScanUnsafeRow}.
*
* If an exception is thrown when applying any of these query optimizations, the action will fail
* and no Spark job will be submitted.
http://git-wip-us.apache.org/repos/asf/spark/blob/75db1486/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanColumnarBatch.java
----------------------------------------------------------------------
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanColumnarBatch.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanColumnarBatch.java
new file mode 100644
index 0000000..27cf3a7
--- /dev/null
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanColumnarBatch.java
@@ -0,0 +1,52 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.sources.v2.reader;
+
+import java.util.List;
+
+import org.apache.spark.annotation.InterfaceStability;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.vectorized.ColumnarBatch;
+
+/**
+ * A mix-in interface for {@link DataSourceV2Reader}. Data source readers can implement this
+ * interface to output {@link ColumnarBatch} and make the scan faster.
+ */
+@InterfaceStability.Evolving
+public interface SupportsScanColumnarBatch extends DataSourceV2Reader {
+ @Override
+ default List<ReadTask<Row>> createReadTasks() {
+ throw new IllegalStateException(
+ "createReadTasks not supported by default within SupportsScanColumnarBatch.");
+ }
+
+ /**
+ * Similar to {@link DataSourceV2Reader#createReadTasks()}, but returns columnar data in batches.
+ */
+ List<ReadTask<ColumnarBatch>> createBatchReadTasks();
+
+ /**
+ * Returns true if the concrete data source reader can read data in batch according to the scan
+ * properties like required columns, pushed filters, etc. It's possible that the implementation
+ * can only support certain columns with certain types. Users can override this method and
+ * {@link #createReadTasks()} to fall back to the normal read path under some conditions.
+ */
+ default boolean enableBatchRead() {
+ return true;
+ }
+}
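As a hedged, hypothetical illustration of the `enableBatchRead()` fallback described above (not part of this patch): a reader can gate batching on the schema and provide a row-based `createReadTasks()` for everything it cannot vectorize.

```scala
import java.util.{Collections, List => JList}

import org.apache.spark.sql.Row
import org.apache.spark.sql.sources.v2.reader.{DataSourceV2Reader, ReadTask, SupportsScanColumnarBatch}
import org.apache.spark.sql.types.{IntegerType, StructType}
import org.apache.spark.sql.vectorized.ColumnarBatch

// Hypothetical reader: vectorize only when every column is an int, otherwise
// fall back to the normal row-based read path.
class MostlyColumnarReader extends DataSourceV2Reader with SupportsScanColumnarBatch {
  private val schema = new StructType().add("i", "int").add("s", "string")

  override def readSchema(): StructType = schema

  override def enableBatchRead(): Boolean =
    schema.fields.forall(_.dataType == IntegerType)

  // Batch path, used only when enableBatchRead() returns true.
  override def createBatchReadTasks(): JList[ReadTask[ColumnarBatch]] =
    Collections.emptyList() // a real reader would return its batch tasks here

  // Row-based fallback, used when enableBatchRead() returns false.
  override def createReadTasks(): JList[ReadTask[Row]] =
    Collections.emptyList() // a real reader would return its row tasks here
}
```

On the planner side, `DataSourceV2ScanExec` (later in this diff) checks `enableBatchRead()` and takes the row-based path when it returns false.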
http://git-wip-us.apache.org/repos/asf/spark/blob/75db1486/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanUnsafeRow.java
----------------------------------------------------------------------
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanUnsafeRow.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanUnsafeRow.java
index b90ec88..2d3ad0e 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanUnsafeRow.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanUnsafeRow.java
@@ -35,7 +35,7 @@ public interface SupportsScanUnsafeRow extends DataSourceV2Reader {
@Override
default List<ReadTask<Row>> createReadTasks() {
throw new IllegalStateException(
- "createReadTasks should not be called with SupportsScanUnsafeRow.");
+ "createReadTasks not supported by default within SupportsScanUnsafeRow");
}
/**
http://git-wip-us.apache.org/repos/asf/spark/blob/75db1486/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala
index 5617046..dd68df9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql.execution
-import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.sql.catalyst.expressions.{BoundReference, UnsafeRow}
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.execution.metric.SQLMetrics
import org.apache.spark.sql.types.DataType
@@ -25,13 +25,16 @@ import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector}
/**
- * Helper trait for abstracting scan functionality using
- * [[ColumnarBatch]]es.
+ * Helper trait for abstracting scan functionality using [[ColumnarBatch]]es.
*/
private[sql] trait ColumnarBatchScan extends CodegenSupport {
def vectorTypes: Option[Seq[String]] = None
+ protected def supportsBatch: Boolean = true
+
+ protected def needsUnsafeRowConversion: Boolean = true
+
override lazy val metrics = Map(
"numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"),
"scanTime" -> SQLMetrics.createTimingMetric(sparkContext, "scan time"))
@@ -71,7 +74,14 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport {
// PhysicalRDD always just has one input
val input = ctx.addMutableState("scala.collection.Iterator", "input",
v => s"$v = inputs[0];")
+ if (supportsBatch) {
+ produceBatches(ctx, input)
+ } else {
+ produceRows(ctx, input)
+ }
+ }
+ private def produceBatches(ctx: CodegenContext, input: String): String = {
// metrics
val numOutputRows = metricTerm(ctx, "numOutputRows")
val scanTimeMetric = metricTerm(ctx, "scanTime")
@@ -137,4 +147,25 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport {
""".stripMargin
}
+ private def produceRows(ctx: CodegenContext, input: String): String = {
+ val numOutputRows = metricTerm(ctx, "numOutputRows")
+ val row = ctx.freshName("row")
+
+ ctx.INPUT_ROW = row
+ ctx.currentVars = null
+ // Always provide `outputVars`, so that the framework can help us build unsafe row if the input
+ // row is not unsafe row, i.e. `needsUnsafeRowConversion` is true.
+ val outputVars = output.zipWithIndex.map { case (a, i) =>
+ BoundReference(i, a.dataType, a.nullable).genCode(ctx)
+ }
+ val inputRow = if (needsUnsafeRowConversion) null else row
+ s"""
+ |while ($input.hasNext()) {
+ | InternalRow $row = (InternalRow) $input.next();
+ | $numOutputRows.add(1);
+ | ${consume(ctx, outputVars, inputRow).trim}
+ | if (shouldStop()) return;
+ |}
+ """.stripMargin
+ }
}
http://git-wip-us.apache.org/repos/asf/spark/blob/75db1486/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
index d1ff82c..7c7d79c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
@@ -164,13 +164,15 @@ case class FileSourceScanExec(
override val tableIdentifier: Option[TableIdentifier])
extends DataSourceScanExec with ColumnarBatchScan {
- val supportsBatch: Boolean = relation.fileFormat.supportBatch(
+ override val supportsBatch: Boolean = relation.fileFormat.supportBatch(
relation.sparkSession, StructType.fromAttributes(output))
- val needsUnsafeRowConversion: Boolean = if (relation.fileFormat.isInstanceOf[ParquetSource]) {
- SparkSession.getActiveSession.get.sessionState.conf.parquetVectorizedReaderEnabled
- } else {
- false
+ override val needsUnsafeRowConversion: Boolean = {
+ if (relation.fileFormat.isInstanceOf[ParquetSource]) {
+ SparkSession.getActiveSession.get.sessionState.conf.parquetVectorizedReaderEnabled
+ } else {
+ false
+ }
}
override def vectorTypes: Option[Seq[String]] =
@@ -346,33 +348,6 @@ case class FileSourceScanExec(
override val nodeNamePrefix: String = "File"
- override protected def doProduce(ctx: CodegenContext): String = {
- if (supportsBatch) {
- return super.doProduce(ctx)
- }
- val numOutputRows = metricTerm(ctx, "numOutputRows")
- // PhysicalRDD always just has one input
- val input = ctx.addMutableState("scala.collection.Iterator", "input", v => s"$v = inputs[0];")
- val row = ctx.freshName("row")
-
- ctx.INPUT_ROW = row
- ctx.currentVars = null
- // Always provide `outputVars`, so that the framework can help us build unsafe row if the input
- // row is not unsafe row, i.e. `needsUnsafeRowConversion` is true.
- val outputVars = output.zipWithIndex.map{ case (a, i) =>
- BoundReference(i, a.dataType, a.nullable).genCode(ctx)
- }
- val inputRow = if (needsUnsafeRowConversion) null else row
- s"""
- |while ($input.hasNext()) {
- | InternalRow $row = (InternalRow) $input.next();
- | $numOutputRows.add(1);
- | ${consume(ctx, outputVars, inputRow).trim}
- | if (shouldStop()) return;
- |}
- """.stripMargin
- }
-
/**
* Create an RDD for bucketed reads.
* The non-bucketed variant of this function is [[createNonBucketedReadRDD]].
http://git-wip-us.apache.org/repos/asf/spark/blob/75db1486/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
index 933b975..3565ee3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
@@ -49,9 +49,9 @@ case class InMemoryTableScanExec(
/**
* If true, get data from ColumnVector in ColumnarBatch, which are generally faster.
- * If false, get data from UnsafeRow build from ColumnVector
+ * If false, get data from UnsafeRow built from CachedBatch
*/
- override val supportCodegen: Boolean = {
+ override val supportsBatch: Boolean = {
// In the initial implementation, for ease of review
// support only primitive data types and # of fields is less than wholeStageMaxNumFields
relation.schema.fields.forall(f => f.dataType match {
@@ -61,6 +61,8 @@ case class InMemoryTableScanExec(
}) && !WholeStageCodegenExec.isTooManyFields(conf, relation.schema)
}
+ override protected def needsUnsafeRowConversion: Boolean = false
+
private val columnIndices =
attributes.map(a => relation.output.map(o => o.exprId).indexOf(a.exprId)).toArray
@@ -90,14 +92,56 @@ case class InMemoryTableScanExec(
columnarBatch
}
- override def inputRDDs(): Seq[RDD[InternalRow]] = {
- assert(supportCodegen)
+ private lazy val inputRDD: RDD[InternalRow] = {
val buffers = filteredCachedBatches()
- // HACK ALERT: This is actually an RDD[ColumnarBatch].
- // We're taking advantage of Scala's type erasure here to pass these batches along.
- Seq(buffers.map(createAndDecompressColumn(_)).asInstanceOf[RDD[InternalRow]])
+ if (supportsBatch) {
+ // HACK ALERT: This is actually an RDD[ColumnarBatch].
+ // We're taking advantage of Scala's type erasure here to pass these batches along.
+ buffers.map(createAndDecompressColumn).asInstanceOf[RDD[InternalRow]]
+ } else {
+ val numOutputRows = longMetric("numOutputRows")
+
+ if (enableAccumulatorsForTest) {
+ readPartitions.setValue(0)
+ readBatches.setValue(0)
+ }
+
+ // Using these variables here to avoid serialization of entire objects (if referenced
+ // directly) within the map Partitions closure.
+ val relOutput: AttributeSeq = relation.output
+
+ filteredCachedBatches().mapPartitionsInternal { cachedBatchIterator =>
+ // Find the ordinals and data types of the requested columns.
+ val (requestedColumnIndices, requestedColumnDataTypes) =
+ attributes.map { a =>
+ relOutput.indexOf(a.exprId) -> a.dataType
+ }.unzip
+
+ // update SQL metrics
+ val withMetrics = cachedBatchIterator.map { batch =>
+ if (enableAccumulatorsForTest) {
+ readBatches.add(1)
+ }
+ numOutputRows += batch.numRows
+ batch
+ }
+
+ val columnTypes = requestedColumnDataTypes.map {
+ case udt: UserDefinedType[_] => udt.sqlType
+ case other => other
+ }.toArray
+ val columnarIterator = GenerateColumnAccessor.generate(columnTypes)
+ columnarIterator.initialize(withMetrics, columnTypes, requestedColumnIndices.toArray)
+ if (enableAccumulatorsForTest && columnarIterator.hasNext) {
+ readPartitions.add(1)
+ }
+ columnarIterator
+ }
+ }
}
+ override def inputRDDs(): Seq[RDD[InternalRow]] = Seq(inputRDD)
+
override def output: Seq[Attribute] = attributes
private def updateAttribute(expr: Expression): Expression = {
@@ -185,7 +229,7 @@ case class InMemoryTableScanExec(
}
}
- lazy val enableAccumulators: Boolean =
+ lazy val enableAccumulatorsForTest: Boolean =
sqlContext.getConf("spark.sql.inMemoryTableScanStatistics.enable", "false").toBoolean
// Accumulators used for testing purposes
@@ -230,43 +274,10 @@ case class InMemoryTableScanExec(
}
protected override def doExecute(): RDD[InternalRow] = {
- val numOutputRows = longMetric("numOutputRows")
-
- if (enableAccumulators) {
- readPartitions.setValue(0)
- readBatches.setValue(0)
- }
-
- // Using these variables here to avoid serialization of entire objects (if referenced directly)
- // within the map Partitions closure.
- val relOutput: AttributeSeq = relation.output
-
- filteredCachedBatches().mapPartitionsInternal { cachedBatchIterator =>
- // Find the ordinals and data types of the requested columns.
- val (requestedColumnIndices, requestedColumnDataTypes) =
- attributes.map { a =>
- relOutput.indexOf(a.exprId) -> a.dataType
- }.unzip
-
- // update SQL metrics
- val withMetrics = cachedBatchIterator.map { batch =>
- if (enableAccumulators) {
- readBatches.add(1)
- }
- numOutputRows += batch.numRows
- batch
- }
-
- val columnTypes = requestedColumnDataTypes.map {
- case udt: UserDefinedType[_] => udt.sqlType
- case other => other
- }.toArray
- val columnarIterator = GenerateColumnAccessor.generate(columnTypes)
- columnarIterator.initialize(withMetrics, columnTypes, requestedColumnIndices.toArray)
- if (enableAccumulators && columnarIterator.hasNext) {
- readPartitions.add(1)
- }
- columnarIterator
+ if (supportsBatch) {
+ WholeStageCodegenExec(this).execute()
+ } else {
+ inputRDD
}
}
}
http://git-wip-us.apache.org/repos/asf/spark/blob/75db1486/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala
index 5f30be5..ac104d7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala
@@ -18,19 +18,19 @@
package org.apache.spark.sql.execution.datasources.v2
import scala.collection.JavaConverters._
+import scala.reflect.ClassTag
import org.apache.spark.{InterruptibleIterator, Partition, SparkContext, TaskContext}
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.sources.v2.reader.ReadTask
-class DataSourceRDDPartition(val index: Int, val readTask: ReadTask[UnsafeRow])
+class DataSourceRDDPartition[T : ClassTag](val index: Int, val readTask: ReadTask[T])
extends Partition with Serializable
-class DataSourceRDD(
+class DataSourceRDD[T: ClassTag](
sc: SparkContext,
- @transient private val readTasks: java.util.List[ReadTask[UnsafeRow]])
- extends RDD[UnsafeRow](sc, Nil) {
+ @transient private val readTasks: java.util.List[ReadTask[T]])
+ extends RDD[T](sc, Nil) {
override protected def getPartitions: Array[Partition] = {
readTasks.asScala.zipWithIndex.map {
@@ -38,10 +38,10 @@ class DataSourceRDD(
}.toArray
}
- override def compute(split: Partition, context: TaskContext): Iterator[UnsafeRow] = {
- val reader = split.asInstanceOf[DataSourceRDDPartition].readTask.createDataReader()
+ override def compute(split: Partition, context: TaskContext): Iterator[T] = {
+ val reader = split.asInstanceOf[DataSourceRDDPartition[T]].readTask.createDataReader()
context.addTaskCompletionListener(_ => reader.close())
- val iter = new Iterator[UnsafeRow] {
+ val iter = new Iterator[T] {
private[this] var valuePrepared = false
override def hasNext: Boolean = {
@@ -51,7 +51,7 @@ class DataSourceRDD(
valuePrepared
}
- override def next(): UnsafeRow = {
+ override def next(): T = {
if (!hasNext) {
throw new java.util.NoSuchElementException("End of stream")
}
@@ -63,6 +63,6 @@ class DataSourceRDD(
}
override def getPreferredLocations(split: Partition): Seq[String] = {
- split.asInstanceOf[DataSourceRDDPartition].readTask.preferredLocations()
+ split.asInstanceOf[DataSourceRDDPartition[T]].readTask.preferredLocations()
}
}
http://git-wip-us.apache.org/repos/asf/spark/blob/75db1486/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala
index 49c506b..8c64df0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala
@@ -24,10 +24,8 @@ import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder}
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.execution.LeafExecNode
-import org.apache.spark.sql.execution.metric.SQLMetrics
-import org.apache.spark.sql.execution.streaming.StreamExecution
-import org.apache.spark.sql.execution.streaming.continuous.{ContinuousDataSourceRDD, ContinuousExecution, EpochCoordinatorRef, SetReaderPartitions}
+import org.apache.spark.sql.execution.{ColumnarBatchScan, LeafExecNode, WholeStageCodegenExec}
+import org.apache.spark.sql.execution.streaming.continuous._
import org.apache.spark.sql.sources.v2.reader._
import org.apache.spark.sql.sources.v2.streaming.reader.ContinuousReader
import org.apache.spark.sql.types.StructType
@@ -37,40 +35,56 @@ import org.apache.spark.sql.types.StructType
*/
case class DataSourceV2ScanExec(
fullOutput: Seq[AttributeReference],
- @transient reader: DataSourceV2Reader) extends LeafExecNode with DataSourceReaderHolder {
+ @transient reader: DataSourceV2Reader)
+ extends LeafExecNode with DataSourceReaderHolder with ColumnarBatchScan {
override def canEqual(other: Any): Boolean = other.isInstanceOf[DataSourceV2ScanExec]
- override def references: AttributeSet = AttributeSet.empty
+ override def producedAttributes: AttributeSet = AttributeSet(fullOutput)
- override lazy val metrics = Map(
- "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
+ private lazy val readTasks: java.util.List[ReadTask[UnsafeRow]] = reader match {
+ case r: SupportsScanUnsafeRow => r.createUnsafeRowReadTasks()
+ case _ =>
+ reader.createReadTasks().asScala.map {
+ new RowToUnsafeRowReadTask(_, reader.readSchema()): ReadTask[UnsafeRow]
+ }.asJava
+ }
- override protected def doExecute(): RDD[InternalRow] = {
- val readTasks: java.util.List[ReadTask[UnsafeRow]] = reader match {
- case r: SupportsScanUnsafeRow => r.createUnsafeRowReadTasks()
- case _ =>
- reader.createReadTasks().asScala.map {
- new RowToUnsafeRowReadTask(_, reader.readSchema()): ReadTask[UnsafeRow]
- }.asJava
- }
+ private lazy val inputRDD: RDD[InternalRow] = reader match {
+ case r: SupportsScanColumnarBatch if r.enableBatchRead() =>
+ assert(!reader.isInstanceOf[ContinuousReader],
+ "continuous stream reader does not support columnar read yet.")
+ new DataSourceRDD(sparkContext, r.createBatchReadTasks()).asInstanceOf[RDD[InternalRow]]
+
+ case _: ContinuousReader =>
+ EpochCoordinatorRef.get(
+ sparkContext.getLocalProperty(ContinuousExecution.RUN_ID_KEY), sparkContext.env)
+ .askSync[Unit](SetReaderPartitions(readTasks.size()))
+ new ContinuousDataSourceRDD(sparkContext, sqlContext, readTasks)
+ .asInstanceOf[RDD[InternalRow]]
+
+ case _ =>
+ new DataSourceRDD(sparkContext, readTasks).asInstanceOf[RDD[InternalRow]]
+ }
- val inputRDD = reader match {
- case _: ContinuousReader =>
- EpochCoordinatorRef.get(
- sparkContext.getLocalProperty(ContinuousExecution.RUN_ID_KEY), sparkContext.env)
- .askSync[Unit](SetReaderPartitions(readTasks.size()))
+ override def inputRDDs(): Seq[RDD[InternalRow]] = Seq(inputRDD)
- new ContinuousDataSourceRDD(sparkContext, sqlContext, readTasks)
+ override val supportsBatch: Boolean = reader match {
+ case r: SupportsScanColumnarBatch if r.enableBatchRead() => true
+ case _ => false
+ }
- case _ =>
- new DataSourceRDD(sparkContext, readTasks)
- }
+ override protected def needsUnsafeRowConversion: Boolean = false
- val numOutputRows = longMetric("numOutputRows")
- inputRDD.asInstanceOf[RDD[InternalRow]].map { r =>
- numOutputRows += 1
- r
+ override protected def doExecute(): RDD[InternalRow] = {
+ if (supportsBatch) {
+ WholeStageCodegenExec(this).execute()
+ } else {
+ val numOutputRows = longMetric("numOutputRows")
+ inputRDD.map { r =>
+ numOutputRows += 1
+ r
+ }
}
}
}
http://git-wip-us.apache.org/repos/asf/spark/blob/75db1486/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala
index d79e4bd..b3f1a1a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala
@@ -52,7 +52,7 @@ class ContinuousDataSourceRDD(
}
override def compute(split: Partition, context: TaskContext): Iterator[UnsafeRow] = {
- val reader = split.asInstanceOf[DataSourceRDDPartition].readTask.createDataReader()
+ val reader = split.asInstanceOf[DataSourceRDDPartition[UnsafeRow]].readTask.createDataReader()
val runId = context.getLocalProperty(ContinuousExecution.RUN_ID_KEY)
@@ -132,7 +132,7 @@ class ContinuousDataSourceRDD(
}
override def getPreferredLocations(split: Partition): Seq[String] = {
- split.asInstanceOf[DataSourceRDDPartition].readTask.preferredLocations()
+ split.asInstanceOf[DataSourceRDDPartition[UnsafeRow]].readTask.preferredLocations()
}
}
http://git-wip-us.apache.org/repos/asf/spark/blob/75db1486/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaBatchDataSourceV2.java
----------------------------------------------------------------------
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaBatchDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaBatchDataSourceV2.java
new file mode 100644
index 0000000..44e5146
--- /dev/null
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaBatchDataSourceV2.java
@@ -0,0 +1,112 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package test.org.apache.spark.sql.sources.v2;
+
+import java.io.IOException;
+import java.util.List;
+
+import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector;
+import org.apache.spark.sql.sources.v2.DataSourceV2;
+import org.apache.spark.sql.sources.v2.DataSourceV2Options;
+import org.apache.spark.sql.sources.v2.ReadSupport;
+import org.apache.spark.sql.sources.v2.reader.*;
+import org.apache.spark.sql.types.DataTypes;
+import org.apache.spark.sql.types.StructType;
+import org.apache.spark.sql.vectorized.ColumnVector;
+import org.apache.spark.sql.vectorized.ColumnarBatch;
+
+
+public class JavaBatchDataSourceV2 implements DataSourceV2, ReadSupport {
+
+ class Reader implements DataSourceV2Reader, SupportsScanColumnarBatch {
+ private final StructType schema = new StructType().add("i", "int").add("j", "int");
+
+ @Override
+ public StructType readSchema() {
+ return schema;
+ }
+
+ @Override
+ public List<ReadTask<ColumnarBatch>> createBatchReadTasks() {
+ return java.util.Arrays.asList(new JavaBatchReadTask(0, 50), new JavaBatchReadTask(50, 90));
+ }
+ }
+
+ static class JavaBatchReadTask implements ReadTask<ColumnarBatch>, DataReader<ColumnarBatch> {
+ private int start;
+ private int end;
+
+ private static final int BATCH_SIZE = 20;
+
+ private OnHeapColumnVector i;
+ private OnHeapColumnVector j;
+ private ColumnarBatch batch;
+
+ JavaBatchReadTask(int start, int end) {
+ this.start = start;
+ this.end = end;
+ }
+
+ @Override
+ public DataReader<ColumnarBatch> createDataReader() {
+ this.i = new OnHeapColumnVector(BATCH_SIZE, DataTypes.IntegerType);
+ this.j = new OnHeapColumnVector(BATCH_SIZE, DataTypes.IntegerType);
+ ColumnVector[] vectors = new ColumnVector[2];
+ vectors[0] = i;
+ vectors[1] = j;
+ this.batch = new ColumnarBatch(new StructType().add("i", "int").add("j", "int"), vectors, BATCH_SIZE);
+ return this;
+ }
+
+ @Override
+ public boolean next() {
+ i.reset();
+ j.reset();
+ int count = 0;
+ while (start < end && count < BATCH_SIZE) {
+ i.putInt(count, start);
+ j.putInt(count, -start);
+ start += 1;
+ count += 1;
+ }
+
+ if (count == 0) {
+ return false;
+ } else {
+ batch.setNumRows(count);
+ return true;
+ }
+ }
+
+ @Override
+ public ColumnarBatch get() {
+ return batch;
+ }
+
+ @Override
+ public void close() throws IOException {
+ batch.close();
+ }
+ }
+
+
+ @Override
+ public DataSourceV2Reader createReader(DataSourceV2Options options) {
+ return new Reader();
+ }
+}
http://git-wip-us.apache.org/repos/asf/spark/blob/75db1486/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
index bc05dca..22ca128 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
@@ -121,31 +121,23 @@ class WholeStageCodegenSuite extends QueryTest with SharedSQLContext {
test("cache for primitive type should be in WholeStageCodegen with InMemoryTableScanExec") {
import testImplicits._
- val dsInt = spark.range(3).cache
- dsInt.count
+ val dsInt = spark.range(3).cache()
+ dsInt.count()
val dsIntFilter = dsInt.filter(_ > 0)
val planInt = dsIntFilter.queryExecution.executedPlan
- assert(planInt.find(p =>
- p.isInstanceOf[WholeStageCodegenExec] &&
- p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[FilterExec] &&
- p.asInstanceOf[WholeStageCodegenExec].child.asInstanceOf[FilterExec].child
- .isInstanceOf[InMemoryTableScanExec] &&
- p.asInstanceOf[WholeStageCodegenExec].child.asInstanceOf[FilterExec].child
- .asInstanceOf[InMemoryTableScanExec].supportCodegen).isDefined
- )
+ assert(planInt.collect {
+ case WholeStageCodegenExec(FilterExec(_, i: InMemoryTableScanExec)) if i.supportsBatch => ()
+ }.length == 1)
assert(dsIntFilter.collect() === Array(1, 2))
// cache for string type is not supported for InMemoryTableScanExec
- val dsString = spark.range(3).map(_.toString).cache
- dsString.count
+ val dsString = spark.range(3).map(_.toString).cache()
+ dsString.count()
val dsStringFilter = dsString.filter(_ == "1")
val planString = dsStringFilter.queryExecution.executedPlan
- assert(planString.find(p =>
- p.isInstanceOf[WholeStageCodegenExec] &&
- p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[FilterExec] &&
- !p.asInstanceOf[WholeStageCodegenExec].child.asInstanceOf[FilterExec].child
- .isInstanceOf[InMemoryTableScanExec]).isDefined
- )
+ assert(planString.collect {
+ case WholeStageCodegenExec(FilterExec(_, i: InMemoryTableScanExec)) if !i.supportsBatch => ()
+ }.length == 1)
assert(dsStringFilter.collect() === Array("1"))
}
http://git-wip-us.apache.org/repos/asf/spark/blob/75db1486/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala
index ab37e49..a89f7c5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala
@@ -24,10 +24,12 @@ import test.org.apache.spark.sql.sources.v2._
import org.apache.spark.SparkException
import org.apache.spark.sql.{AnalysisException, QueryTest, Row}
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector
import org.apache.spark.sql.sources.{Filter, GreaterThan}
import org.apache.spark.sql.sources.v2.reader._
import org.apache.spark.sql.test.SharedSQLContext
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.types.{IntegerType, StructType}
+import org.apache.spark.sql.vectorized.ColumnarBatch
class DataSourceV2Suite extends QueryTest with SharedSQLContext {
import testImplicits._
@@ -56,7 +58,7 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext {
}
}
- test("unsafe row implementation") {
+ test("unsafe row scan implementation") {
Seq(classOf[UnsafeRowDataSourceV2], classOf[JavaUnsafeRowDataSourceV2]).foreach { cls =>
withClue(cls.getName) {
val df = spark.read.format(cls.getName).load()
@@ -67,6 +69,17 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext {
}
}
+ test("columnar batch scan implementation") {
+ Seq(classOf[BatchDataSourceV2], classOf[JavaBatchDataSourceV2]).foreach { cls =>
+ withClue(cls.getName) {
+ val df = spark.read.format(cls.getName).load()
+ checkAnswer(df, (0 until 90).map(i => Row(i, -i)))
+ checkAnswer(df.select('j), (0 until 90).map(i => Row(-i)))
+ checkAnswer(df.filter('i > 50), (51 until 90).map(i => Row(i, -i)))
+ }
+ }
+ }
+
test("schema required data source") {
Seq(classOf[SchemaRequiredDataSource], classOf[JavaSchemaRequiredDataSource]).foreach { cls =>
withClue(cls.getName) {
@@ -275,7 +288,7 @@ class UnsafeRowReadTask(start: Int, end: Int)
private var current = start - 1
- override def createDataReader(): DataReader[UnsafeRow] = new UnsafeRowReadTask(start, end)
+ override def createDataReader(): DataReader[UnsafeRow] = this
override def next(): Boolean = {
current += 1
@@ -300,3 +313,56 @@ class SchemaRequiredDataSource extends DataSourceV2 with ReadSupportWithSchema {
override def createReader(schema: StructType, options: DataSourceV2Options): DataSourceV2Reader =
new Reader(schema)
}
+
+class BatchDataSourceV2 extends DataSourceV2 with ReadSupport {
+
+ class Reader extends DataSourceV2Reader with SupportsScanColumnarBatch {
+ override def readSchema(): StructType = new StructType().add("i", "int").add("j", "int")
+
+ override def createBatchReadTasks(): JList[ReadTask[ColumnarBatch]] = {
+ java.util.Arrays.asList(new BatchReadTask(0, 50), new BatchReadTask(50, 90))
+ }
+ }
+
+ override def createReader(options: DataSourceV2Options): DataSourceV2Reader = new Reader
+}
+
+class BatchReadTask(start: Int, end: Int)
+ extends ReadTask[ColumnarBatch] with DataReader[ColumnarBatch] {
+
+ private final val BATCH_SIZE = 20
+ private lazy val i = new OnHeapColumnVector(BATCH_SIZE, IntegerType)
+ private lazy val j = new OnHeapColumnVector(BATCH_SIZE, IntegerType)
+ private lazy val batch = new ColumnarBatch(
+ new StructType().add("i", "int").add("j", "int"), Array(i, j), BATCH_SIZE)
+
+ private var current = start
+
+ override def createDataReader(): DataReader[ColumnarBatch] = this
+
+ override def next(): Boolean = {
+ i.reset()
+ j.reset()
+
+ var count = 0
+ while (current < end && count < BATCH_SIZE) {
+ i.putInt(count, current)
+ j.putInt(count, -current)
+ current += 1
+ count += 1
+ }
+
+ if (count == 0) {
+ false
+ } else {
+ batch.setNumRows(count)
+ true
+ }
+ }
+
+ override def get(): ColumnarBatch = {
+ batch
+ }
+
+ override def close(): Unit = batch.close()
+}