Posted to commits@spark.apache.org by we...@apache.org on 2018/01/16 14:41:49 UTC

spark git commit: [SPARK-22392][SQL] data source v2 columnar batch reader

Repository: spark
Updated Branches:
  refs/heads/master b85eb946a -> 75db14864


[SPARK-22392][SQL] data source v2 columnar batch reader

## What changes were proposed in this pull request?

This PR adds a new Data Source V2 mix-in interface, `SupportsScanColumnarBatch`, that allows a data source to return `ColumnarBatch` during the scan.
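
A condensed Scala sketch of what an implementation looks like (the full Scala and Java versions are in the new test sources in this patch; `BatchReadTask` is the columnar read task added to `DataSourceV2Suite.scala`, and `MyColumnarSource` is just an illustrative name):

```scala
import java.util.{Arrays => JArrays, List => JList}

import org.apache.spark.sql.sources.v2.{DataSourceV2, DataSourceV2Options, ReadSupport}
import org.apache.spark.sql.sources.v2.reader._
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.vectorized.ColumnarBatch

class MyColumnarSource extends DataSourceV2 with ReadSupport {

  class Reader extends DataSourceV2Reader with SupportsScanColumnarBatch {
    override def readSchema(): StructType = new StructType().add("i", "int").add("j", "int")

    // Instead of createReadTasks(), a batch-capable reader hands Spark tasks that produce
    // whole ColumnarBatches (BatchReadTask is defined in DataSourceV2Suite.scala below).
    override def createBatchReadTasks(): JList[ReadTask[ColumnarBatch]] =
      JArrays.asList(new BatchReadTask(0, 50), new BatchReadTask(50, 90))
  }

  override def createReader(options: DataSourceV2Options): DataSourceV2Reader = new Reader
}
```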

## How was this patch tested?

New tests: a columnar batch scan test in `DataSourceV2Suite`, backed by new Scala (`BatchDataSourceV2`) and Java (`JavaBatchDataSourceV2`) test data sources, plus updated `WholeStageCodegenSuite` tests.

Author: Wenchen Fan <we...@databricks.com>

Closes #20153 from cloud-fan/columnar-reader.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/75db1486
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/75db1486
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/75db1486

Branch: refs/heads/master
Commit: 75db14864d2bd9b8e13154226e94d466e3a7e0a0
Parents: b85eb94
Author: Wenchen Fan <we...@databricks.com>
Authored: Tue Jan 16 22:41:30 2018 +0800
Committer: Wenchen Fan <we...@databricks.com>
Committed: Tue Jan 16 22:41:30 2018 +0800

----------------------------------------------------------------------
 .../sources/v2/reader/DataSourceV2Reader.java   |   5 +-
 .../v2/reader/SupportsScanColumnarBatch.java    |  52 +++++++++
 .../v2/reader/SupportsScanUnsafeRow.java        |   2 +-
 .../spark/sql/execution/ColumnarBatchScan.scala |  37 +++++-
 .../sql/execution/DataSourceScanExec.scala      |  39 ++-----
 .../columnar/InMemoryTableScanExec.scala        | 101 +++++++++--------
 .../datasources/v2/DataSourceRDD.scala          |  20 ++--
 .../datasources/v2/DataSourceV2ScanExec.scala   |  72 +++++++-----
 .../ContinuousDataSourceRDDIter.scala           |   4 +-
 .../sql/sources/v2/JavaBatchDataSourceV2.java   | 112 +++++++++++++++++++
 .../sql/execution/WholeStageCodegenSuite.scala  |  28 ++---
 .../sql/sources/v2/DataSourceV2Suite.scala      |  72 +++++++++++-
 12 files changed, 400 insertions(+), 144 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/75db1486/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceV2Reader.java
----------------------------------------------------------------------
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceV2Reader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceV2Reader.java
index 95ee4a8..f23c384 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceV2Reader.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceV2Reader.java
@@ -38,7 +38,10 @@ import org.apache.spark.sql.types.StructType;
  *   2. Information Reporting. E.g., statistics reporting, ordering reporting, etc.
  *      Names of these interfaces start with `SupportsReporting`.
  *   3. Special scans. E.g, columnar scan, unsafe row scan, etc.
- *      Names of these interfaces start with `SupportsScan`.
+ *      Names of these interfaces start with `SupportsScan`. Note that a reader should implement
+ *      at most one of the special scans. If more than one special scan is implemented, only one
+ *      of them is respected, according to the following priority list, from high to low:
+ *      {@link SupportsScanColumnarBatch}, {@link SupportsScanUnsafeRow}.
  *
  * If an exception is thrown when applying any of these query optimizations, the action fails
  * and no Spark job is submitted.
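
For context, this priority list is what the new scan-path selection in `DataSourceV2ScanExec.scala` (further down in this diff) implements. Condensed, and ignoring the continuous-streaming branch, it amounts to the following sketch, where `reader` and `sparkContext` are fields of the plan node:

```scala
import scala.collection.JavaConverters._

val inputRDD: RDD[InternalRow] = reader match {
  case r: SupportsScanColumnarBatch if r.enableBatchRead() =>
    // Highest priority: the reader emits ColumnarBatch directly.
    new DataSourceRDD(sparkContext, r.createBatchReadTasks()).asInstanceOf[RDD[InternalRow]]

  case r: SupportsScanUnsafeRow =>
    // Next: the reader already emits UnsafeRow.
    new DataSourceRDD(sparkContext, r.createUnsafeRowReadTasks()).asInstanceOf[RDD[InternalRow]]

  case _ =>
    // Fallback: plain Row tasks, wrapped so that they produce UnsafeRow.
    val unsafeTasks = reader.createReadTasks().asScala.map {
      new RowToUnsafeRowReadTask(_, reader.readSchema()): ReadTask[UnsafeRow]
    }.asJava
    new DataSourceRDD(sparkContext, unsafeTasks).asInstanceOf[RDD[InternalRow]]
}
```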

http://git-wip-us.apache.org/repos/asf/spark/blob/75db1486/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanColumnarBatch.java
----------------------------------------------------------------------
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanColumnarBatch.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanColumnarBatch.java
new file mode 100644
index 0000000..27cf3a7
--- /dev/null
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanColumnarBatch.java
@@ -0,0 +1,52 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.sources.v2.reader;
+
+import java.util.List;
+
+import org.apache.spark.annotation.InterfaceStability;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.vectorized.ColumnarBatch;
+
+/**
+ * A mix-in interface for {@link DataSourceV2Reader}. Data source readers can implement this
+ * interface to output {@link ColumnarBatch} and make the scan faster.
+ */
+@InterfaceStability.Evolving
+public interface SupportsScanColumnarBatch extends DataSourceV2Reader {
+  @Override
+  default List<ReadTask<Row>> createReadTasks() {
+    throw new IllegalStateException(
+      "createReadTasks not supported by default within SupportsScanColumnarBatch.");
+  }
+
+  /**
+   * Similar to {@link DataSourceV2Reader#createReadTasks()}, but returns columnar data in batches.
+   */
+  List<ReadTask<ColumnarBatch>> createBatchReadTasks();
+
+  /**
+   * Returns true if the concrete data source reader can read data in batches according to the scan
+   * properties like required columns, pushed filters, etc. It's possible that the implementation
+   * can only support certain columns with certain types. Implementations can override this method
+   * and {@link #createReadTasks()} to fall back to the normal read path under such conditions.
+   */
+  default boolean enableBatchRead() {
+    return true;
+  }
+}
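
The `enableBatchRead()` hook above is how an implementation restricts the vectorized path to the schemas it can actually handle. A minimal Scala sketch (not part of this patch): `BatchReadTask` is the columnar task added to `DataSourceV2Suite.scala` in this change, while `RowReadTask` is a hypothetical row-based `ReadTask[Row]` standing in for the reader's own fallback path.

```scala
import java.util.{Arrays => JArrays, List => JList}

import org.apache.spark.sql.Row
import org.apache.spark.sql.sources.v2.reader._
import org.apache.spark.sql.types._
import org.apache.spark.sql.vectorized.ColumnarBatch

class FallbackReader(requiredSchema: StructType)
  extends DataSourceV2Reader with SupportsScanColumnarBatch {

  override def readSchema(): StructType = requiredSchema

  // Read in batches only when every requested column has a type the vectorized path
  // supports; otherwise Spark calls createReadTasks() and takes the normal row path.
  override def enableBatchRead(): Boolean = requiredSchema.fields.forall { f =>
    f.dataType match {
      case IntegerType | LongType | DoubleType => true
      case _ => false
    }
  }

  override def createBatchReadTasks(): JList[ReadTask[ColumnarBatch]] =
    JArrays.asList(new BatchReadTask(0, 90): ReadTask[ColumnarBatch])

  // RowReadTask is a hypothetical, non-vectorized task implementation.
  override def createReadTasks(): JList[ReadTask[Row]] =
    JArrays.asList(new RowReadTask(0, 90): ReadTask[Row])
}
```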

http://git-wip-us.apache.org/repos/asf/spark/blob/75db1486/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanUnsafeRow.java
----------------------------------------------------------------------
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanUnsafeRow.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanUnsafeRow.java
index b90ec88..2d3ad0e 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanUnsafeRow.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsScanUnsafeRow.java
@@ -35,7 +35,7 @@ public interface SupportsScanUnsafeRow extends DataSourceV2Reader {
   @Override
   default List<ReadTask<Row>> createReadTasks() {
     throw new IllegalStateException(
-        "createReadTasks should not be called with SupportsScanUnsafeRow.");
+      "createReadTasks not supported by default within SupportsScanUnsafeRow");
   }
 
   /**

http://git-wip-us.apache.org/repos/asf/spark/blob/75db1486/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala
index 5617046..dd68df9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql.execution
 
-import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.sql.catalyst.expressions.{BoundReference, UnsafeRow}
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
 import org.apache.spark.sql.execution.metric.SQLMetrics
 import org.apache.spark.sql.types.DataType
@@ -25,13 +25,16 @@ import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector}
 
 
 /**
- * Helper trait for abstracting scan functionality using
- * [[ColumnarBatch]]es.
+ * Helper trait for abstracting scan functionality using [[ColumnarBatch]]es.
  */
 private[sql] trait ColumnarBatchScan extends CodegenSupport {
 
   def vectorTypes: Option[Seq[String]] = None
 
+  protected def supportsBatch: Boolean = true
+
+  protected def needsUnsafeRowConversion: Boolean = true
+
   override lazy val metrics = Map(
     "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"),
     "scanTime" -> SQLMetrics.createTimingMetric(sparkContext, "scan time"))
@@ -71,7 +74,14 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport {
     // PhysicalRDD always just has one input
     val input = ctx.addMutableState("scala.collection.Iterator", "input",
       v => s"$v = inputs[0];")
+    if (supportsBatch) {
+      produceBatches(ctx, input)
+    } else {
+      produceRows(ctx, input)
+    }
+  }
 
+  private def produceBatches(ctx: CodegenContext, input: String): String = {
     // metrics
     val numOutputRows = metricTerm(ctx, "numOutputRows")
     val scanTimeMetric = metricTerm(ctx, "scanTime")
@@ -137,4 +147,25 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport {
      """.stripMargin
   }
 
+  private def produceRows(ctx: CodegenContext, input: String): String = {
+    val numOutputRows = metricTerm(ctx, "numOutputRows")
+    val row = ctx.freshName("row")
+
+    ctx.INPUT_ROW = row
+    ctx.currentVars = null
+    // Always provide `outputVars`, so that the framework can build an unsafe row if the input
+    // row is not an unsafe row, i.e. when `needsUnsafeRowConversion` is true.
+    val outputVars = output.zipWithIndex.map { case (a, i) =>
+      BoundReference(i, a.dataType, a.nullable).genCode(ctx)
+    }
+    val inputRow = if (needsUnsafeRowConversion) null else row
+    s"""
+       |while ($input.hasNext()) {
+       |  InternalRow $row = (InternalRow) $input.next();
+       |  $numOutputRows.add(1);
+       |  ${consume(ctx, outputVars, inputRow).trim}
+       |  if (shouldStop()) return;
+       |}
+     """.stripMargin
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/75db1486/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
index d1ff82c..7c7d79c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
@@ -164,13 +164,15 @@ case class FileSourceScanExec(
     override val tableIdentifier: Option[TableIdentifier])
   extends DataSourceScanExec with ColumnarBatchScan  {
 
-  val supportsBatch: Boolean = relation.fileFormat.supportBatch(
+  override val supportsBatch: Boolean = relation.fileFormat.supportBatch(
     relation.sparkSession, StructType.fromAttributes(output))
 
-  val needsUnsafeRowConversion: Boolean = if (relation.fileFormat.isInstanceOf[ParquetSource]) {
-    SparkSession.getActiveSession.get.sessionState.conf.parquetVectorizedReaderEnabled
-  } else {
-    false
+  override val needsUnsafeRowConversion: Boolean = {
+    if (relation.fileFormat.isInstanceOf[ParquetSource]) {
+      SparkSession.getActiveSession.get.sessionState.conf.parquetVectorizedReaderEnabled
+    } else {
+      false
+    }
   }
 
   override def vectorTypes: Option[Seq[String]] =
@@ -346,33 +348,6 @@ case class FileSourceScanExec(
 
   override val nodeNamePrefix: String = "File"
 
-  override protected def doProduce(ctx: CodegenContext): String = {
-    if (supportsBatch) {
-      return super.doProduce(ctx)
-    }
-    val numOutputRows = metricTerm(ctx, "numOutputRows")
-    // PhysicalRDD always just has one input
-    val input = ctx.addMutableState("scala.collection.Iterator", "input", v => s"$v = inputs[0];")
-    val row = ctx.freshName("row")
-
-    ctx.INPUT_ROW = row
-    ctx.currentVars = null
-    // Always provide `outputVars`, so that the framework can help us build unsafe row if the input
-    // row is not unsafe row, i.e. `needsUnsafeRowConversion` is true.
-    val outputVars = output.zipWithIndex.map{ case (a, i) =>
-      BoundReference(i, a.dataType, a.nullable).genCode(ctx)
-    }
-    val inputRow = if (needsUnsafeRowConversion) null else row
-    s"""
-       |while ($input.hasNext()) {
-       |  InternalRow $row = (InternalRow) $input.next();
-       |  $numOutputRows.add(1);
-       |  ${consume(ctx, outputVars, inputRow).trim}
-       |  if (shouldStop()) return;
-       |}
-     """.stripMargin
-  }
-
   /**
    * Create an RDD for bucketed reads.
    * The non-bucketed variant of this function is [[createNonBucketedReadRDD]].

http://git-wip-us.apache.org/repos/asf/spark/blob/75db1486/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
index 933b975..3565ee3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
@@ -49,9 +49,9 @@ case class InMemoryTableScanExec(
 
   /**
    * If true, get data from ColumnVector in ColumnarBatch, which are generally faster.
-   * If false, get data from UnsafeRow build from ColumnVector
+   * If false, get data from UnsafeRow built from CachedBatch
    */
-  override val supportCodegen: Boolean = {
+  override val supportsBatch: Boolean = {
     // In the initial implementation, for ease of review
     // support only primitive data types and # of fields is less than wholeStageMaxNumFields
     relation.schema.fields.forall(f => f.dataType match {
@@ -61,6 +61,8 @@ case class InMemoryTableScanExec(
     }) && !WholeStageCodegenExec.isTooManyFields(conf, relation.schema)
   }
 
+  override protected def needsUnsafeRowConversion: Boolean = false
+
   private val columnIndices =
     attributes.map(a => relation.output.map(o => o.exprId).indexOf(a.exprId)).toArray
 
@@ -90,14 +92,56 @@ case class InMemoryTableScanExec(
     columnarBatch
   }
 
-  override def inputRDDs(): Seq[RDD[InternalRow]] = {
-    assert(supportCodegen)
+  private lazy val inputRDD: RDD[InternalRow] = {
     val buffers = filteredCachedBatches()
-    // HACK ALERT: This is actually an RDD[ColumnarBatch].
-    // We're taking advantage of Scala's type erasure here to pass these batches along.
-    Seq(buffers.map(createAndDecompressColumn(_)).asInstanceOf[RDD[InternalRow]])
+    if (supportsBatch) {
+      // HACK ALERT: This is actually an RDD[ColumnarBatch].
+      // We're taking advantage of Scala's type erasure here to pass these batches along.
+      buffers.map(createAndDecompressColumn).asInstanceOf[RDD[InternalRow]]
+    } else {
+      val numOutputRows = longMetric("numOutputRows")
+
+      if (enableAccumulatorsForTest) {
+        readPartitions.setValue(0)
+        readBatches.setValue(0)
+      }
+
+      // Using these variables here to avoid serialization of entire objects (if referenced
+      // directly) within the map Partitions closure.
+      val relOutput: AttributeSeq = relation.output
+
+      filteredCachedBatches().mapPartitionsInternal { cachedBatchIterator =>
+        // Find the ordinals and data types of the requested columns.
+        val (requestedColumnIndices, requestedColumnDataTypes) =
+          attributes.map { a =>
+            relOutput.indexOf(a.exprId) -> a.dataType
+          }.unzip
+
+        // update SQL metrics
+        val withMetrics = cachedBatchIterator.map { batch =>
+          if (enableAccumulatorsForTest) {
+            readBatches.add(1)
+          }
+          numOutputRows += batch.numRows
+          batch
+        }
+
+        val columnTypes = requestedColumnDataTypes.map {
+          case udt: UserDefinedType[_] => udt.sqlType
+          case other => other
+        }.toArray
+        val columnarIterator = GenerateColumnAccessor.generate(columnTypes)
+        columnarIterator.initialize(withMetrics, columnTypes, requestedColumnIndices.toArray)
+        if (enableAccumulatorsForTest && columnarIterator.hasNext) {
+          readPartitions.add(1)
+        }
+        columnarIterator
+      }
+    }
   }
 
+  override def inputRDDs(): Seq[RDD[InternalRow]] = Seq(inputRDD)
+
   override def output: Seq[Attribute] = attributes
 
   private def updateAttribute(expr: Expression): Expression = {
@@ -185,7 +229,7 @@ case class InMemoryTableScanExec(
     }
   }
 
-  lazy val enableAccumulators: Boolean =
+  lazy val enableAccumulatorsForTest: Boolean =
     sqlContext.getConf("spark.sql.inMemoryTableScanStatistics.enable", "false").toBoolean
 
   // Accumulators used for testing purposes
@@ -230,43 +274,10 @@ case class InMemoryTableScanExec(
   }
 
   protected override def doExecute(): RDD[InternalRow] = {
-    val numOutputRows = longMetric("numOutputRows")
-
-    if (enableAccumulators) {
-      readPartitions.setValue(0)
-      readBatches.setValue(0)
-    }
-
-    // Using these variables here to avoid serialization of entire objects (if referenced directly)
-    // within the map Partitions closure.
-    val relOutput: AttributeSeq = relation.output
-
-    filteredCachedBatches().mapPartitionsInternal { cachedBatchIterator =>
-      // Find the ordinals and data types of the requested columns.
-      val (requestedColumnIndices, requestedColumnDataTypes) =
-        attributes.map { a =>
-          relOutput.indexOf(a.exprId) -> a.dataType
-        }.unzip
-
-      // update SQL metrics
-      val withMetrics = cachedBatchIterator.map { batch =>
-        if (enableAccumulators) {
-          readBatches.add(1)
-        }
-        numOutputRows += batch.numRows
-        batch
-      }
-
-      val columnTypes = requestedColumnDataTypes.map {
-        case udt: UserDefinedType[_] => udt.sqlType
-        case other => other
-      }.toArray
-      val columnarIterator = GenerateColumnAccessor.generate(columnTypes)
-      columnarIterator.initialize(withMetrics, columnTypes, requestedColumnIndices.toArray)
-      if (enableAccumulators && columnarIterator.hasNext) {
-        readPartitions.add(1)
-      }
-      columnarIterator
+    if (supportsBatch) {
+      WholeStageCodegenExec(this).execute()
+    } else {
+      inputRDD
     }
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/75db1486/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala
index 5f30be5..ac104d7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala
@@ -18,19 +18,19 @@
 package org.apache.spark.sql.execution.datasources.v2
 
 import scala.collection.JavaConverters._
+import scala.reflect.ClassTag
 
 import org.apache.spark.{InterruptibleIterator, Partition, SparkContext, TaskContext}
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.catalyst.expressions.UnsafeRow
 import org.apache.spark.sql.sources.v2.reader.ReadTask
 
-class DataSourceRDDPartition(val index: Int, val readTask: ReadTask[UnsafeRow])
+class DataSourceRDDPartition[T : ClassTag](val index: Int, val readTask: ReadTask[T])
   extends Partition with Serializable
 
-class DataSourceRDD(
+class DataSourceRDD[T: ClassTag](
     sc: SparkContext,
-    @transient private val readTasks: java.util.List[ReadTask[UnsafeRow]])
-  extends RDD[UnsafeRow](sc, Nil) {
+    @transient private val readTasks: java.util.List[ReadTask[T]])
+  extends RDD[T](sc, Nil) {
 
   override protected def getPartitions: Array[Partition] = {
     readTasks.asScala.zipWithIndex.map {
@@ -38,10 +38,10 @@ class DataSourceRDD(
     }.toArray
   }
 
-  override def compute(split: Partition, context: TaskContext): Iterator[UnsafeRow] = {
-    val reader = split.asInstanceOf[DataSourceRDDPartition].readTask.createDataReader()
+  override def compute(split: Partition, context: TaskContext): Iterator[T] = {
+    val reader = split.asInstanceOf[DataSourceRDDPartition[T]].readTask.createDataReader()
     context.addTaskCompletionListener(_ => reader.close())
-    val iter = new Iterator[UnsafeRow] {
+    val iter = new Iterator[T] {
       private[this] var valuePrepared = false
 
       override def hasNext: Boolean = {
@@ -51,7 +51,7 @@ class DataSourceRDD(
         valuePrepared
       }
 
-      override def next(): UnsafeRow = {
+      override def next(): T = {
         if (!hasNext) {
           throw new java.util.NoSuchElementException("End of stream")
         }
@@ -63,6 +63,6 @@ class DataSourceRDD(
   }
 
   override def getPreferredLocations(split: Partition): Seq[String] = {
-    split.asInstanceOf[DataSourceRDDPartition].readTask.preferredLocations()
+    split.asInstanceOf[DataSourceRDDPartition[T]].readTask.preferredLocations()
   }
 }
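
With the type parameter in place, one RDD class now backs both the row path and the new batch path. A small sketch of the two instantiations (the concrete call sites are in `DataSourceV2ScanExec.scala` below; `rowRDD` / `batchRDD` are just illustrative helpers, and the task lists are assumed to come from the reader's createUnsafeRowReadTasks() / createBatchReadTasks()):

```scala
import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.execution.datasources.v2.DataSourceRDD
import org.apache.spark.sql.sources.v2.reader.ReadTask
import org.apache.spark.sql.vectorized.ColumnarBatch

object DataSourceRDDUsage {
  // Row-based scan: the RDD element type is UnsafeRow.
  def rowRDD(sc: SparkContext, tasks: java.util.List[ReadTask[UnsafeRow]]): RDD[UnsafeRow] =
    new DataSourceRDD[UnsafeRow](sc, tasks)

  // Columnar scan: the same RDD class, now carrying whole ColumnarBatches.
  def batchRDD(sc: SparkContext, tasks: java.util.List[ReadTask[ColumnarBatch]]): RDD[ColumnarBatch] =
    new DataSourceRDD[ColumnarBatch](sc, tasks)
}
```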

http://git-wip-us.apache.org/repos/asf/spark/blob/75db1486/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala
index 49c506b..8c64df0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala
@@ -24,10 +24,8 @@ import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder}
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.execution.LeafExecNode
-import org.apache.spark.sql.execution.metric.SQLMetrics
-import org.apache.spark.sql.execution.streaming.StreamExecution
-import org.apache.spark.sql.execution.streaming.continuous.{ContinuousDataSourceRDD, ContinuousExecution, EpochCoordinatorRef, SetReaderPartitions}
+import org.apache.spark.sql.execution.{ColumnarBatchScan, LeafExecNode, WholeStageCodegenExec}
+import org.apache.spark.sql.execution.streaming.continuous._
 import org.apache.spark.sql.sources.v2.reader._
 import org.apache.spark.sql.sources.v2.streaming.reader.ContinuousReader
 import org.apache.spark.sql.types.StructType
@@ -37,40 +35,56 @@ import org.apache.spark.sql.types.StructType
  */
 case class DataSourceV2ScanExec(
     fullOutput: Seq[AttributeReference],
-    @transient reader: DataSourceV2Reader) extends LeafExecNode with DataSourceReaderHolder {
+    @transient reader: DataSourceV2Reader)
+  extends LeafExecNode with DataSourceReaderHolder with ColumnarBatchScan {
 
   override def canEqual(other: Any): Boolean = other.isInstanceOf[DataSourceV2ScanExec]
 
-  override def references: AttributeSet = AttributeSet.empty
+  override def producedAttributes: AttributeSet = AttributeSet(fullOutput)
 
-  override lazy val metrics = Map(
-    "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
+  private lazy val readTasks: java.util.List[ReadTask[UnsafeRow]] = reader match {
+    case r: SupportsScanUnsafeRow => r.createUnsafeRowReadTasks()
+    case _ =>
+      reader.createReadTasks().asScala.map {
+        new RowToUnsafeRowReadTask(_, reader.readSchema()): ReadTask[UnsafeRow]
+      }.asJava
+  }
 
-  override protected def doExecute(): RDD[InternalRow] = {
-    val readTasks: java.util.List[ReadTask[UnsafeRow]] = reader match {
-      case r: SupportsScanUnsafeRow => r.createUnsafeRowReadTasks()
-      case _ =>
-        reader.createReadTasks().asScala.map {
-          new RowToUnsafeRowReadTask(_, reader.readSchema()): ReadTask[UnsafeRow]
-        }.asJava
-    }
+  private lazy val inputRDD: RDD[InternalRow] = reader match {
+    case r: SupportsScanColumnarBatch if r.enableBatchRead() =>
+      assert(!reader.isInstanceOf[ContinuousReader],
+        "continuous stream reader does not support columnar read yet.")
+      new DataSourceRDD(sparkContext, r.createBatchReadTasks()).asInstanceOf[RDD[InternalRow]]
+
+    case _: ContinuousReader =>
+      EpochCoordinatorRef.get(
+        sparkContext.getLocalProperty(ContinuousExecution.RUN_ID_KEY), sparkContext.env)
+        .askSync[Unit](SetReaderPartitions(readTasks.size()))
+      new ContinuousDataSourceRDD(sparkContext, sqlContext, readTasks)
+        .asInstanceOf[RDD[InternalRow]]
+
+    case _ =>
+      new DataSourceRDD(sparkContext, readTasks).asInstanceOf[RDD[InternalRow]]
+  }
 
-    val inputRDD = reader match {
-      case _: ContinuousReader =>
-        EpochCoordinatorRef.get(
-          sparkContext.getLocalProperty(ContinuousExecution.RUN_ID_KEY), sparkContext.env)
-          .askSync[Unit](SetReaderPartitions(readTasks.size()))
+  override def inputRDDs(): Seq[RDD[InternalRow]] = Seq(inputRDD)
 
-        new ContinuousDataSourceRDD(sparkContext, sqlContext, readTasks)
+  override val supportsBatch: Boolean = reader match {
+    case r: SupportsScanColumnarBatch if r.enableBatchRead() => true
+    case _ => false
+  }
 
-      case _ =>
-        new DataSourceRDD(sparkContext, readTasks)
-    }
+  override protected def needsUnsafeRowConversion: Boolean = false
 
-    val numOutputRows = longMetric("numOutputRows")
-    inputRDD.asInstanceOf[RDD[InternalRow]].map { r =>
-      numOutputRows += 1
-      r
+  override protected def doExecute(): RDD[InternalRow] = {
+    if (supportsBatch) {
+      WholeStageCodegenExec(this).execute()
+    } else {
+      val numOutputRows = longMetric("numOutputRows")
+      inputRDD.map { r =>
+        numOutputRows += 1
+        r
+      }
     }
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/75db1486/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala
index d79e4bd..b3f1a1a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala
@@ -52,7 +52,7 @@ class ContinuousDataSourceRDD(
   }
 
   override def compute(split: Partition, context: TaskContext): Iterator[UnsafeRow] = {
-    val reader = split.asInstanceOf[DataSourceRDDPartition].readTask.createDataReader()
+    val reader = split.asInstanceOf[DataSourceRDDPartition[UnsafeRow]].readTask.createDataReader()
 
     val runId = context.getLocalProperty(ContinuousExecution.RUN_ID_KEY)
 
@@ -132,7 +132,7 @@ class ContinuousDataSourceRDD(
   }
 
   override def getPreferredLocations(split: Partition): Seq[String] = {
-    split.asInstanceOf[DataSourceRDDPartition].readTask.preferredLocations()
+    split.asInstanceOf[DataSourceRDDPartition[UnsafeRow]].readTask.preferredLocations()
   }
 }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/75db1486/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaBatchDataSourceV2.java
----------------------------------------------------------------------
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaBatchDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaBatchDataSourceV2.java
new file mode 100644
index 0000000..44e5146
--- /dev/null
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaBatchDataSourceV2.java
@@ -0,0 +1,112 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package test.org.apache.spark.sql.sources.v2;
+
+import java.io.IOException;
+import java.util.List;
+
+import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector;
+import org.apache.spark.sql.sources.v2.DataSourceV2;
+import org.apache.spark.sql.sources.v2.DataSourceV2Options;
+import org.apache.spark.sql.sources.v2.ReadSupport;
+import org.apache.spark.sql.sources.v2.reader.*;
+import org.apache.spark.sql.types.DataTypes;
+import org.apache.spark.sql.types.StructType;
+import org.apache.spark.sql.vectorized.ColumnVector;
+import org.apache.spark.sql.vectorized.ColumnarBatch;
+
+
+public class JavaBatchDataSourceV2 implements DataSourceV2, ReadSupport {
+
+  class Reader implements DataSourceV2Reader, SupportsScanColumnarBatch {
+    private final StructType schema = new StructType().add("i", "int").add("j", "int");
+
+    @Override
+    public StructType readSchema() {
+      return schema;
+    }
+
+    @Override
+    public List<ReadTask<ColumnarBatch>> createBatchReadTasks() {
+      return java.util.Arrays.asList(new JavaBatchReadTask(0, 50), new JavaBatchReadTask(50, 90));
+    }
+  }
+
+  static class JavaBatchReadTask implements ReadTask<ColumnarBatch>, DataReader<ColumnarBatch> {
+    private int start;
+    private int end;
+
+    private static final int BATCH_SIZE = 20;
+
+    private OnHeapColumnVector i;
+    private OnHeapColumnVector j;
+    private ColumnarBatch batch;
+
+    JavaBatchReadTask(int start, int end) {
+      this.start = start;
+      this.end = end;
+    }
+
+    @Override
+    public DataReader<ColumnarBatch> createDataReader() {
+      this.i = new OnHeapColumnVector(BATCH_SIZE, DataTypes.IntegerType);
+      this.j = new OnHeapColumnVector(BATCH_SIZE, DataTypes.IntegerType);
+      ColumnVector[] vectors = new ColumnVector[2];
+      vectors[0] = i;
+      vectors[1] = j;
+      this.batch = new ColumnarBatch(new StructType().add("i", "int").add("j", "int"), vectors, BATCH_SIZE);
+      return this;
+    }
+
+    @Override
+    public boolean next() {
+      i.reset();
+      j.reset();
+      int count = 0;
+      while (start < end && count < BATCH_SIZE) {
+        i.putInt(count, start);
+        j.putInt(count, -start);
+        start += 1;
+        count += 1;
+      }
+
+      if (count == 0) {
+        return false;
+      } else {
+        batch.setNumRows(count);
+        return true;
+      }
+    }
+
+    @Override
+    public ColumnarBatch get() {
+      return batch;
+    }
+
+    @Override
+    public void close() throws IOException {
+      batch.close();
+    }
+  }
+
+
+  @Override
+  public DataSourceV2Reader createReader(DataSourceV2Options options) {
+    return new Reader();
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/75db1486/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
index bc05dca..22ca128 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
@@ -121,31 +121,23 @@ class WholeStageCodegenSuite extends QueryTest with SharedSQLContext {
   test("cache for primitive type should be in WholeStageCodegen with InMemoryTableScanExec") {
     import testImplicits._
 
-    val dsInt = spark.range(3).cache
-    dsInt.count
+    val dsInt = spark.range(3).cache()
+    dsInt.count()
     val dsIntFilter = dsInt.filter(_ > 0)
     val planInt = dsIntFilter.queryExecution.executedPlan
-    assert(planInt.find(p =>
-      p.isInstanceOf[WholeStageCodegenExec] &&
-      p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[FilterExec] &&
-      p.asInstanceOf[WholeStageCodegenExec].child.asInstanceOf[FilterExec].child
-        .isInstanceOf[InMemoryTableScanExec] &&
-      p.asInstanceOf[WholeStageCodegenExec].child.asInstanceOf[FilterExec].child
-        .asInstanceOf[InMemoryTableScanExec].supportCodegen).isDefined
-    )
+    assert(planInt.collect {
+      case WholeStageCodegenExec(FilterExec(_, i: InMemoryTableScanExec)) if i.supportsBatch => ()
+    }.length == 1)
     assert(dsIntFilter.collect() === Array(1, 2))
 
     // cache for string type is not supported for InMemoryTableScanExec
-    val dsString = spark.range(3).map(_.toString).cache
-    dsString.count
+    val dsString = spark.range(3).map(_.toString).cache()
+    dsString.count()
     val dsStringFilter = dsString.filter(_ == "1")
     val planString = dsStringFilter.queryExecution.executedPlan
-    assert(planString.find(p =>
-      p.isInstanceOf[WholeStageCodegenExec] &&
-      p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[FilterExec] &&
-      !p.asInstanceOf[WholeStageCodegenExec].child.asInstanceOf[FilterExec].child
-        .isInstanceOf[InMemoryTableScanExec]).isDefined
-    )
+    assert(planString.collect {
+      case WholeStageCodegenExec(FilterExec(_, i: InMemoryTableScanExec)) if !i.supportsBatch => ()
+    }.length == 1)
     assert(dsStringFilter.collect() === Array("1"))
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/75db1486/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala
index ab37e49..a89f7c5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala
@@ -24,10 +24,12 @@ import test.org.apache.spark.sql.sources.v2._
 import org.apache.spark.SparkException
 import org.apache.spark.sql.{AnalysisException, QueryTest, Row}
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector
 import org.apache.spark.sql.sources.{Filter, GreaterThan}
 import org.apache.spark.sql.sources.v2.reader._
 import org.apache.spark.sql.test.SharedSQLContext
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.types.{IntegerType, StructType}
+import org.apache.spark.sql.vectorized.ColumnarBatch
 
 class DataSourceV2Suite extends QueryTest with SharedSQLContext {
   import testImplicits._
@@ -56,7 +58,7 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext {
     }
   }
 
-  test("unsafe row implementation") {
+  test("unsafe row scan implementation") {
     Seq(classOf[UnsafeRowDataSourceV2], classOf[JavaUnsafeRowDataSourceV2]).foreach { cls =>
       withClue(cls.getName) {
         val df = spark.read.format(cls.getName).load()
@@ -67,6 +69,17 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext {
     }
   }
 
+  test("columnar batch scan implementation") {
+    Seq(classOf[BatchDataSourceV2], classOf[JavaBatchDataSourceV2]).foreach { cls =>
+      withClue(cls.getName) {
+        val df = spark.read.format(cls.getName).load()
+        checkAnswer(df, (0 until 90).map(i => Row(i, -i)))
+        checkAnswer(df.select('j), (0 until 90).map(i => Row(-i)))
+        checkAnswer(df.filter('i > 50), (51 until 90).map(i => Row(i, -i)))
+      }
+    }
+  }
+
   test("schema required data source") {
     Seq(classOf[SchemaRequiredDataSource], classOf[JavaSchemaRequiredDataSource]).foreach { cls =>
       withClue(cls.getName) {
@@ -275,7 +288,7 @@ class UnsafeRowReadTask(start: Int, end: Int)
 
   private var current = start - 1
 
-  override def createDataReader(): DataReader[UnsafeRow] = new UnsafeRowReadTask(start, end)
+  override def createDataReader(): DataReader[UnsafeRow] = this
 
   override def next(): Boolean = {
     current += 1
@@ -300,3 +313,56 @@ class SchemaRequiredDataSource extends DataSourceV2 with ReadSupportWithSchema {
   override def createReader(schema: StructType, options: DataSourceV2Options): DataSourceV2Reader =
     new Reader(schema)
 }
+
+class BatchDataSourceV2 extends DataSourceV2 with ReadSupport {
+
+  class Reader extends DataSourceV2Reader with SupportsScanColumnarBatch {
+    override def readSchema(): StructType = new StructType().add("i", "int").add("j", "int")
+
+    override def createBatchReadTasks(): JList[ReadTask[ColumnarBatch]] = {
+      java.util.Arrays.asList(new BatchReadTask(0, 50), new BatchReadTask(50, 90))
+    }
+  }
+
+  override def createReader(options: DataSourceV2Options): DataSourceV2Reader = new Reader
+}
+
+class BatchReadTask(start: Int, end: Int)
+  extends ReadTask[ColumnarBatch] with DataReader[ColumnarBatch] {
+
+  private final val BATCH_SIZE = 20
+  private lazy val i = new OnHeapColumnVector(BATCH_SIZE, IntegerType)
+  private lazy val j = new OnHeapColumnVector(BATCH_SIZE, IntegerType)
+  private lazy val batch = new ColumnarBatch(
+    new StructType().add("i", "int").add("j", "int"), Array(i, j), BATCH_SIZE)
+
+  private var current = start
+
+  override def createDataReader(): DataReader[ColumnarBatch] = this
+
+  override def next(): Boolean = {
+    i.reset()
+    j.reset()
+
+    var count = 0
+    while (current < end && count < BATCH_SIZE) {
+      i.putInt(count, current)
+      j.putInt(count, -current)
+      current += 1
+      count += 1
+    }
+
+    if (count == 0) {
+      false
+    } else {
+      batch.setNumRows(count)
+      true
+    }
+  }
+
+  override def get(): ColumnarBatch = {
+    batch
+  }
+
+  override def close(): Unit = batch.close()
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org