You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by we...@apache.org on 2018/03/20 18:46:54 UTC

spark git commit: [SPARK-23574][SQL] Report SinglePartition in DataSourceV2ScanExec when there's exactly 1 data reader factory.

Repository: spark
Updated Branches:
  refs/heads/master 7f5e8aa26 -> 2c4b9962f


[SPARK-23574][SQL] Report SinglePartition in DataSourceV2ScanExec when there's exactly 1 data reader factory.

## What changes were proposed in this pull request?

Report SinglePartition in DataSourceV2ScanExec when there's exactly 1 data reader factory.

Note that this means the reader factories are now constructed when the output partitioning is checked (i.e., during query planning) rather than when the scan is first executed; let me know if you think that could be a problem.

## How was this patch tested?

Existing unit tests, plus a new test in DataSourceV2Suite ("SPARK-23574: no shuffle exchange with single partition").

Author: Jose Torres <jo...@databricks.com>
Author: Jose Torres <to...@gmail.com>

Closes #20726 from jose-torres/SPARK-23574.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/2c4b9962
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/2c4b9962
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/2c4b9962

Branch: refs/heads/master
Commit: 2c4b9962fdf8c1beb66126ca41628c72eb6c2383
Parents: 7f5e8aa
Author: Jose Torres <jo...@databricks.com>
Authored: Tue Mar 20 11:46:51 2018 -0700
Committer: Wenchen Fan <we...@databricks.com>
Committed: Tue Mar 20 11:46:51 2018 -0700

----------------------------------------------------------------------
 .../v2/reader/SupportsReportPartitioning.java   |  3 ++
 .../datasources/v2/DataSourceRDD.scala          |  4 +--
 .../datasources/v2/DataSourceV2ScanExec.scala   | 29 +++++++++++++++-----
 .../ContinuousDataSourceRDDIter.scala           |  4 +--
 .../sql/sources/v2/DataSourceV2Suite.scala      | 20 +++++++++++++-
 .../sql/streaming/StreamingQuerySuite.scala     |  4 +--
 6 files changed, 50 insertions(+), 14 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/2c4b9962/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java
----------------------------------------------------------------------
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java
index 5405a91..6076287 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java
@@ -23,6 +23,9 @@ import org.apache.spark.sql.sources.v2.reader.partitioning.Partitioning;
 /**
  * A mix in interface for {@link DataSourceReader}. Data source readers can implement this
  * interface to report data partitioning and try to avoid shuffle at Spark side.
+ *
+ * Note that, when the reader creates exactly one {@link DataReaderFactory}, Spark may avoid
+ * adding a shuffle even if the reader does not implement this interface.
  */
 @InterfaceStability.Evolving
 public interface SupportsReportPartitioning extends DataSourceReader {

http://git-wip-us.apache.org/repos/asf/spark/blob/2c4b9962/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala
index 5ed0ba7..f85971b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala
@@ -29,11 +29,11 @@ class DataSourceRDDPartition[T : ClassTag](val index: Int, val readerFactory: Da
 
 class DataSourceRDD[T: ClassTag](
     sc: SparkContext,
-    @transient private val readerFactories: java.util.List[DataReaderFactory[T]])
+    @transient private val readerFactories: Seq[DataReaderFactory[T]])
   extends RDD[T](sc, Nil) {
 
   override protected def getPartitions: Array[Partition] = {
-    readerFactories.asScala.zipWithIndex.map {
+    readerFactories.zipWithIndex.map {
       case (readerFactory, index) => new DataSourceRDDPartition(index, readerFactory)
     }.toArray
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/2c4b9962/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala
index cb691ba..3a5e7bf 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala
@@ -25,12 +25,14 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.physical
+import org.apache.spark.sql.catalyst.plans.physical.SinglePartition
 import org.apache.spark.sql.execution.{ColumnarBatchScan, LeafExecNode, WholeStageCodegenExec}
 import org.apache.spark.sql.execution.streaming.continuous._
 import org.apache.spark.sql.sources.v2.DataSourceV2
 import org.apache.spark.sql.sources.v2.reader._
 import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousReader
 import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.vectorized.ColumnarBatch
 
 /**
  * Physical plan node for scanning data from a data source.
@@ -56,6 +58,15 @@ case class DataSourceV2ScanExec(
   }
 
   override def outputPartitioning: physical.Partitioning = reader match {
+    case r: SupportsScanColumnarBatch if r.enableBatchRead() && batchReaderFactories.size == 1 =>
+      SinglePartition
+
+    case r: SupportsScanColumnarBatch if !r.enableBatchRead() && readerFactories.size == 1 =>
+      SinglePartition
+
+    case r if !r.isInstanceOf[SupportsScanColumnarBatch] && readerFactories.size == 1 =>
+      SinglePartition
+
     case s: SupportsReportPartitioning =>
       new DataSourcePartitioning(
         s.outputPartitioning(), AttributeMap(output.map(a => a -> a.name)))
@@ -63,29 +74,33 @@ case class DataSourceV2ScanExec(
     case _ => super.outputPartitioning
   }
 
-  private lazy val readerFactories: java.util.List[DataReaderFactory[UnsafeRow]] = reader match {
-    case r: SupportsScanUnsafeRow => r.createUnsafeRowReaderFactories()
+  private lazy val readerFactories: Seq[DataReaderFactory[UnsafeRow]] = reader match {
+    case r: SupportsScanUnsafeRow => r.createUnsafeRowReaderFactories().asScala
     case _ =>
       reader.createDataReaderFactories().asScala.map {
         new RowToUnsafeRowDataReaderFactory(_, reader.readSchema()): DataReaderFactory[UnsafeRow]
-      }.asJava
+      }
   }
 
-  private lazy val inputRDD: RDD[InternalRow] = reader match {
+  private lazy val batchReaderFactories: Seq[DataReaderFactory[ColumnarBatch]] = reader match {
     case r: SupportsScanColumnarBatch if r.enableBatchRead() =>
       assert(!reader.isInstanceOf[ContinuousReader],
         "continuous stream reader does not support columnar read yet.")
-      new DataSourceRDD(sparkContext, r.createBatchDataReaderFactories())
-        .asInstanceOf[RDD[InternalRow]]
+      r.createBatchDataReaderFactories().asScala
+  }
 
+  private lazy val inputRDD: RDD[InternalRow] = reader match {
     case _: ContinuousReader =>
       EpochCoordinatorRef.get(
           sparkContext.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY),
           sparkContext.env)
-        .askSync[Unit](SetReaderPartitions(readerFactories.size()))
+        .askSync[Unit](SetReaderPartitions(readerFactories.size))
       new ContinuousDataSourceRDD(sparkContext, sqlContext, readerFactories)
         .asInstanceOf[RDD[InternalRow]]
 
+    case r: SupportsScanColumnarBatch if r.enableBatchRead() =>
+      new DataSourceRDD(sparkContext, batchReaderFactories).asInstanceOf[RDD[InternalRow]]
+
     case _ =>
       new DataSourceRDD(sparkContext, readerFactories).asInstanceOf[RDD[InternalRow]]
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/2c4b9962/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala
index cf02c0d..06754f0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala
@@ -35,14 +35,14 @@ import org.apache.spark.util.ThreadUtils
 class ContinuousDataSourceRDD(
     sc: SparkContext,
     sqlContext: SQLContext,
-    @transient private val readerFactories: java.util.List[DataReaderFactory[UnsafeRow]])
+    @transient private val readerFactories: Seq[DataReaderFactory[UnsafeRow]])
   extends RDD[UnsafeRow](sc, Nil) {
 
   private val dataQueueSize = sqlContext.conf.continuousStreamingExecutorQueueSize
   private val epochPollIntervalMs = sqlContext.conf.continuousStreamingExecutorPollIntervalMs
 
   override protected def getPartitions: Array[Partition] = {
-    readerFactories.asScala.zipWithIndex.map {
+    readerFactories.zipWithIndex.map {
       case (readerFactory, index) => new DataSourceRDDPartition(index, readerFactory)
     }.toArray
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/2c4b9962/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala
index 1157a35..e0a5327 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala
@@ -25,7 +25,7 @@ import org.apache.spark.SparkException
 import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row}
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow
 import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, DataSourceV2ScanExec}
-import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
+import org.apache.spark.sql.execution.exchange.{Exchange, ShuffleExchangeExec}
 import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.sources.{Filter, GreaterThan}
@@ -191,6 +191,11 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext {
     }
   }
 
+  test("SPARK-23574: no shuffle exchange with single partition") {
+    val df = spark.read.format(classOf[SimpleSinglePartitionSource].getName).load().agg(count("*"))
+    assert(df.queryExecution.executedPlan.collect { case e: Exchange => e }.isEmpty)
+  }
+
   test("simple writable data source") {
     // TODO: java implementation.
     Seq(classOf[SimpleWritableDataSource]).foreach { cls =>
@@ -336,6 +341,19 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext {
   }
 }
 
+class SimpleSinglePartitionSource extends DataSourceV2 with ReadSupport {
+
+  class Reader extends DataSourceReader {
+    override def readSchema(): StructType = new StructType().add("i", "int").add("j", "int")
+
+    override def createDataReaderFactories(): JList[DataReaderFactory[Row]] = {
+      java.util.Arrays.asList(new SimpleDataReaderFactory(0, 5))
+    }
+  }
+
+  override def createReader(options: DataSourceOptions): DataSourceReader = new Reader
+}
+
 class SimpleDataSourceV2 extends DataSourceV2 with ReadSupport {
 
   class Reader extends DataSourceReader {

http://git-wip-us.apache.org/repos/asf/spark/blob/2c4b9962/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala
index 3f9aa0d..ebc9a87 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala
@@ -326,9 +326,9 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi
 
         assert(progress.durationMs.get("setOffsetRange") === 50)
         assert(progress.durationMs.get("getEndOffset") === 100)
-        assert(progress.durationMs.get("queryPlanning") === 0)
+        assert(progress.durationMs.get("queryPlanning") === 200)
         assert(progress.durationMs.get("walCommit") === 0)
-        assert(progress.durationMs.get("addBatch") === 350)
+        assert(progress.durationMs.get("addBatch") === 150)
         assert(progress.durationMs.get("triggerExecution") === 500)
 
         assert(progress.sources.length === 1)


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org