Posted to commits@spark.apache.org by rx...@apache.org on 2016/04/23 02:44:04 UTC
[4/7] spark git commit: [SPARK-14855][SQL] Add "Exec" suffix to physical operators
http://git-wip-us.apache.org/repos/asf/spark/blob/d7d0cad0/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala
index 45a3213..971770a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala
@@ -42,7 +42,7 @@ private[sql] trait RunnableCommand extends LogicalPlan with logical.Command {
* A physical operator that executes the run method of a `RunnableCommand` and
* saves the result to prevent multiple executions.
*/
-private[sql] case class ExecutedCommand(cmd: RunnableCommand) extends SparkPlan {
+private[sql] case class ExecutedCommandExec(cmd: RunnableCommand) extends SparkPlan {
/**
* A concrete command should override this lazy field to wrap up any side effects caused by the
* command or any other computation that should be evaluated exactly once. The value of this field
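For orientation, the class renamed above is the physical wrapper that runs a RunnableCommand's run method once and caches the resulting rows. A minimal, hypothetical sketch of how a planner strategy hands a logical command to this wrapper (mirroring the DataSourceStrategy change further down; the helper name is made up, and note that several of these classes are private[sql], so such code only compiles inside Spark's sql package):

    import org.apache.spark.sql.execution.SparkPlan
    import org.apache.spark.sql.execution.command.{ExecutedCommandExec, RunnableCommand}

    // Hypothetical strategy fragment: a logical RunnableCommand is planned as the
    // physical ExecutedCommandExec wrapper, which evaluates the command exactly once.
    def planCommand(cmd: RunnableCommand): SparkPlan = ExecutedCommandExec(cmd)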
http://git-wip-us.apache.org/repos/asf/spark/blob/d7d0cad0/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
index ac3c52e..9bebd74 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
@@ -32,8 +32,8 @@ import org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.execution.DataSourceScan.{INPUT_PATHS, PUSHED_FILTERS}
-import org.apache.spark.sql.execution.command.ExecutedCommand
+import org.apache.spark.sql.execution.DataSourceScanExec.{INPUT_PATHS, PUSHED_FILTERS}
+import org.apache.spark.sql.execution.command.ExecutedCommandExec
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
@@ -105,12 +105,12 @@ private[sql] object DataSourceStrategy extends Strategy with Logging {
(a, _) => toCatalystRDD(l, a, t.buildScan(a.map(_.name).toArray))) :: Nil
case l @ LogicalRelation(baseRelation: TableScan, _, _) =>
- execution.DataSourceScan.create(
+ execution.DataSourceScanExec.create(
l.output, toCatalystRDD(l, baseRelation.buildScan()), baseRelation) :: Nil
case i @ logical.InsertIntoTable(l @ LogicalRelation(t: InsertableRelation, _, _),
part, query, overwrite, false) if part.isEmpty =>
- ExecutedCommand(InsertIntoDataSource(l, query, overwrite)) :: Nil
+ ExecutedCommandExec(InsertIntoDataSource(l, query, overwrite)) :: Nil
case _ => Nil
}
@@ -214,22 +214,22 @@ private[sql] object DataSourceStrategy extends Strategy with Logging {
// Don't request columns that are only referenced by pushed filters.
.filterNot(handledSet.contains)
- val scan = execution.DataSourceScan.create(
+ val scan = execution.DataSourceScanExec.create(
projects.map(_.toAttribute),
scanBuilder(requestedColumns, candidatePredicates, pushedFilters),
relation.relation, metadata)
- filterCondition.map(execution.Filter(_, scan)).getOrElse(scan)
+ filterCondition.map(execution.FilterExec(_, scan)).getOrElse(scan)
} else {
// Don't request columns that are only referenced by pushed filters.
val requestedColumns =
(projectSet ++ filterSet -- handledSet).map(relation.attributeMap).toSeq
- val scan = execution.DataSourceScan.create(
+ val scan = execution.DataSourceScanExec.create(
requestedColumns,
scanBuilder(requestedColumns, candidatePredicates, pushedFilters),
relation.relation, metadata)
- execution.Project(
- projects, filterCondition.map(execution.Filter(_, scan)).getOrElse(scan))
+ execution.ProjectExec(
+ projects, filterCondition.map(execution.FilterExec(_, scan)).getOrElse(scan))
}
}
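The renames in this strategy ripple into any code that traverses physical plans. An illustrative helper using the new names (not part of this change; FilterExec is simply the renamed Filter operator):

    import org.apache.spark.sql.execution.{FilterExec, SparkPlan}

    // Illustrative only: plan traversals now match the *Exec operator names.
    def filtersIn(plan: SparkPlan): Seq[FilterExec] =
      plan.collect { case f: FilterExec => f }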
http://git-wip-us.apache.org/repos/asf/spark/blob/d7d0cad0/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala
index c1a97de..751daa0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala
@@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.planning.PhysicalOperation
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
-import org.apache.spark.sql.execution.{DataSourceScan, SparkPlan}
+import org.apache.spark.sql.execution.{DataSourceScanExec, SparkPlan}
/**
* A strategy for planning scans over collections of files that might be partitioned or bucketed
@@ -192,7 +192,7 @@ private[sql] object FileSourceStrategy extends Strategy with Logging {
}
val scan =
- DataSourceScan.create(
+ DataSourceScanExec.create(
readDataColumns ++ partitionColumns,
new FileScanRDD(
files.sqlContext,
@@ -205,11 +205,11 @@ private[sql] object FileSourceStrategy extends Strategy with Logging {
"ReadSchema" -> prunedDataSchema.simpleString))
val afterScanFilter = afterScanFilters.toSeq.reduceOption(expressions.And)
- val withFilter = afterScanFilter.map(execution.Filter(_, scan)).getOrElse(scan)
+ val withFilter = afterScanFilter.map(execution.FilterExec(_, scan)).getOrElse(scan)
val withProjections = if (projects == withFilter.output) {
withFilter
} else {
- execution.Project(projects, withFilter)
+ execution.ProjectExec(projects, withFilter)
}
withProjections :: Nil
http://git-wip-us.apache.org/repos/asf/spark/blob/d7d0cad0/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
index e6079ec..5b96ab1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
@@ -49,9 +49,9 @@ package object debug {
}
def codegenString(plan: SparkPlan): String = {
- val codegenSubtrees = new collection.mutable.HashSet[WholeStageCodegen]()
+ val codegenSubtrees = new collection.mutable.HashSet[WholeStageCodegenExec]()
plan transform {
- case s: WholeStageCodegen =>
+ case s: WholeStageCodegenExec =>
codegenSubtrees += s
s
case s => s
@@ -86,11 +86,11 @@ package object debug {
val debugPlan = plan transform {
case s: SparkPlan if !visited.contains(new TreeNodeRef(s)) =>
visited += new TreeNodeRef(s)
- DebugNode(s)
+ DebugExec(s)
}
debugPrint(s"Results returned: ${debugPlan.execute().count()}")
debugPlan.foreach {
- case d: DebugNode => d.dumpStats()
+ case d: DebugExec => d.dumpStats()
case _ =>
}
}
@@ -104,7 +104,7 @@ package object debug {
}
}
- private[sql] case class DebugNode(child: SparkPlan) extends UnaryNode with CodegenSupport {
+ private[sql] case class DebugExec(child: SparkPlan) extends UnaryExecNode with CodegenSupport {
def output: Seq[Attribute] = child.output
implicit object SetAccumulatorParam extends AccumulatorParam[HashSet[String]] {
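For readers who have not used this package: the helpers patched above are exposed through an implicit conversion pulled in by importing org.apache.spark.sql.execution.debug._. A hedged usage sketch, assuming a Spark 2.0-era session named spark:

    import org.apache.spark.sql.execution.debug._

    val df = spark.range(100).selectExpr("id % 10 AS k").groupBy("k").count()
    df.debug()         // wraps each physical node in DebugExec, runs the query, prints per-operator stats
    df.debugCodegen()  // prints the generated code for each WholeStageCodegenExec subtree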
http://git-wip-us.apache.org/repos/asf/spark/blob/d7d0cad0/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchange.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchange.scala
deleted file mode 100644
index 87a113e..0000000
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchange.scala
+++ /dev/null
@@ -1,109 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.execution.exchange
-
-import scala.concurrent.{ExecutionContext, Future}
-import scala.concurrent.duration._
-
-import org.apache.spark.broadcast
-import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.UnsafeRow
-import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode, BroadcastPartitioning, Partitioning}
-import org.apache.spark.sql.execution.{SparkPlan, SQLExecution}
-import org.apache.spark.sql.execution.metric.SQLMetrics
-import org.apache.spark.util.ThreadUtils
-
-/**
- * A [[BroadcastExchange]] collects, transforms and finally broadcasts the result of a transformed
- * SparkPlan.
- */
-case class BroadcastExchange(
- mode: BroadcastMode,
- child: SparkPlan) extends Exchange {
-
- override private[sql] lazy val metrics = Map(
- "dataSize" -> SQLMetrics.createLongMetric(sparkContext, "data size (bytes)"),
- "collectTime" -> SQLMetrics.createLongMetric(sparkContext, "time to collect (ms)"),
- "buildTime" -> SQLMetrics.createLongMetric(sparkContext, "time to build (ms)"),
- "broadcastTime" -> SQLMetrics.createLongMetric(sparkContext, "time to broadcast (ms)"))
-
- override def outputPartitioning: Partitioning = BroadcastPartitioning(mode)
-
- override def sameResult(plan: SparkPlan): Boolean = plan match {
- case p: BroadcastExchange =>
- mode.compatibleWith(p.mode) && child.sameResult(p.child)
- case _ => false
- }
-
- @transient
- private val timeout: Duration = {
- val timeoutValue = sqlContext.conf.broadcastTimeout
- if (timeoutValue < 0) {
- Duration.Inf
- } else {
- timeoutValue.seconds
- }
- }
-
- @transient
- private lazy val relationFuture: Future[broadcast.Broadcast[Any]] = {
- // broadcastFuture is used in "doExecute". Therefore we can get the execution id correctly here.
- val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
- Future {
- // This will run in another thread. Set the execution id so that we can connect these jobs
- // with the correct execution.
- SQLExecution.withExecutionId(sparkContext, executionId) {
- val beforeCollect = System.nanoTime()
- // Note that we use .executeCollect() because we don't want to convert data to Scala types
- val input: Array[InternalRow] = child.executeCollect()
- val beforeBuild = System.nanoTime()
- longMetric("collectTime") += (beforeBuild - beforeCollect) / 1000000
- longMetric("dataSize") += input.map(_.asInstanceOf[UnsafeRow].getSizeInBytes.toLong).sum
-
- // Construct and broadcast the relation.
- val relation = mode.transform(input)
- val beforeBroadcast = System.nanoTime()
- longMetric("buildTime") += (beforeBroadcast - beforeBuild) / 1000000
-
- val broadcasted = sparkContext.broadcast(relation)
- longMetric("broadcastTime") += (System.nanoTime() - beforeBroadcast) / 1000000
- broadcasted
- }
- }(BroadcastExchange.executionContext)
- }
-
- override protected def doPrepare(): Unit = {
- // Materialize the future.
- relationFuture
- }
-
- override protected def doExecute(): RDD[InternalRow] = {
- throw new UnsupportedOperationException(
- "BroadcastExchange does not support the execute() code path.")
- }
-
- override protected[sql] def doExecuteBroadcast[T](): broadcast.Broadcast[T] = {
- ThreadUtils.awaitResult(relationFuture, timeout).asInstanceOf[broadcast.Broadcast[T]]
- }
-}
-
-object BroadcastExchange {
- private[execution] val executionContext = ExecutionContext.fromExecutorService(
- ThreadUtils.newDaemonCachedThreadPool("broadcast-exchange", 128))
-}
http://git-wip-us.apache.org/repos/asf/spark/blob/d7d0cad0/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala
new file mode 100644
index 0000000..573ca19
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala
@@ -0,0 +1,109 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.exchange
+
+import scala.concurrent.{ExecutionContext, Future}
+import scala.concurrent.duration._
+
+import org.apache.spark.broadcast
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode, BroadcastPartitioning, Partitioning}
+import org.apache.spark.sql.execution.{SparkPlan, SQLExecution}
+import org.apache.spark.sql.execution.metric.SQLMetrics
+import org.apache.spark.util.ThreadUtils
+
+/**
+ * A [[BroadcastExchangeExec]] collects, transforms and finally broadcasts the result of
+ * a transformed SparkPlan.
+ */
+case class BroadcastExchangeExec(
+ mode: BroadcastMode,
+ child: SparkPlan) extends Exchange {
+
+ override private[sql] lazy val metrics = Map(
+ "dataSize" -> SQLMetrics.createLongMetric(sparkContext, "data size (bytes)"),
+ "collectTime" -> SQLMetrics.createLongMetric(sparkContext, "time to collect (ms)"),
+ "buildTime" -> SQLMetrics.createLongMetric(sparkContext, "time to build (ms)"),
+ "broadcastTime" -> SQLMetrics.createLongMetric(sparkContext, "time to broadcast (ms)"))
+
+ override def outputPartitioning: Partitioning = BroadcastPartitioning(mode)
+
+ override def sameResult(plan: SparkPlan): Boolean = plan match {
+ case p: BroadcastExchangeExec =>
+ mode.compatibleWith(p.mode) && child.sameResult(p.child)
+ case _ => false
+ }
+
+ @transient
+ private val timeout: Duration = {
+ val timeoutValue = sqlContext.conf.broadcastTimeout
+ if (timeoutValue < 0) {
+ Duration.Inf
+ } else {
+ timeoutValue.seconds
+ }
+ }
+
+ @transient
+ private lazy val relationFuture: Future[broadcast.Broadcast[Any]] = {
+ // broadcastFuture is used in "doExecute". Therefore we can get the execution id correctly here.
+ val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
+ Future {
+ // This will run in another thread. Set the execution id so that we can connect these jobs
+ // with the correct execution.
+ SQLExecution.withExecutionId(sparkContext, executionId) {
+ val beforeCollect = System.nanoTime()
+ // Note that we use .executeCollect() because we don't want to convert data to Scala types
+ val input: Array[InternalRow] = child.executeCollect()
+ val beforeBuild = System.nanoTime()
+ longMetric("collectTime") += (beforeBuild - beforeCollect) / 1000000
+ longMetric("dataSize") += input.map(_.asInstanceOf[UnsafeRow].getSizeInBytes.toLong).sum
+
+ // Construct and broadcast the relation.
+ val relation = mode.transform(input)
+ val beforeBroadcast = System.nanoTime()
+ longMetric("buildTime") += (beforeBroadcast - beforeBuild) / 1000000
+
+ val broadcasted = sparkContext.broadcast(relation)
+ longMetric("broadcastTime") += (System.nanoTime() - beforeBroadcast) / 1000000
+ broadcasted
+ }
+ }(BroadcastExchangeExec.executionContext)
+ }
+
+ override protected def doPrepare(): Unit = {
+ // Materialize the future.
+ relationFuture
+ }
+
+ override protected def doExecute(): RDD[InternalRow] = {
+ throw new UnsupportedOperationException(
+ "BroadcastExchange does not support the execute() code path.")
+ }
+
+ override protected[sql] def doExecuteBroadcast[T](): broadcast.Broadcast[T] = {
+ ThreadUtils.awaitResult(relationFuture, timeout).asInstanceOf[broadcast.Broadcast[T]]
+ }
+}
+
+object BroadcastExchangeExec {
+ private[execution] val executionContext = ExecutionContext.fromExecutorService(
+ ThreadUtils.newDaemonCachedThreadPool("broadcast-exchange", 128))
+}
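One note on the timeout handling in this new file: broadcastTimeout is read from the SQL conf (the spark.sql.broadcastTimeout key, in seconds), and a negative value is treated as an infinite wait. An illustrative way to raise it when large broadcasts time out, assuming a Spark 2.0-style session named spark:

    // Illustrative configuration only; 600 seconds is an arbitrary example value.
    spark.conf.set("spark.sql.broadcastTimeout", "600")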
http://git-wip-us.apache.org/repos/asf/spark/blob/d7d0cad0/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
index 4864db7..446571a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
@@ -160,7 +160,7 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
case (child, distribution) if child.outputPartitioning.satisfies(distribution) =>
child
case (child, BroadcastDistribution(mode)) =>
- BroadcastExchange(mode, child)
+ BroadcastExchangeExec(mode, child)
case (child, distribution) =>
ShuffleExchange(createPartitioning(distribution, defaultNumPreShufflePartitions), child)
}
@@ -237,7 +237,7 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
if (requiredOrdering.nonEmpty) {
// If child.outputOrdering is [a, b] and requiredOrdering is [a], we do not need to sort.
if (requiredOrdering != child.outputOrdering.take(requiredOrdering.length)) {
- Sort(requiredOrdering, global = false, child = child)
+ SortExec(requiredOrdering, global = false, child = child)
} else {
child
}
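To make the ordering comment above concrete, here is a toy illustration of the prefix rule, with plain strings standing in for SortOrder expressions: a child already ordered by [a, b] satisfies a required ordering of [a], so no SortExec is inserted.

    val requiredOrdering = Seq("a")
    val childOrdering    = Seq("a", "b")
    val needsSort = requiredOrdering != childOrdering.take(requiredOrdering.length)
    // needsSort == false here, so the child is used as-is;
    // otherwise SortExec(requiredOrdering, global = false, child) would be added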
http://git-wip-us.apache.org/repos/asf/spark/blob/d7d0cad0/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala
index df7ad48..9da9df6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala
@@ -25,7 +25,7 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.execution.{LeafNode, SparkPlan, UnaryNode}
+import org.apache.spark.sql.execution.{LeafExecNode, SparkPlan, UnaryExecNode}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.StructType
@@ -36,7 +36,7 @@ import org.apache.spark.sql.types.StructType
* differs significantly, the concept is similar to the exchange operator described in
* "Volcano -- An Extensible and Parallel Query Evaluation System" by Goetz Graefe.
*/
-abstract class Exchange extends UnaryNode {
+abstract class Exchange extends UnaryExecNode {
override def output: Seq[Attribute] = child.output
}
@@ -45,7 +45,8 @@ abstract class Exchange extends UnaryNode {
* logically identical output will have distinct sets of output attribute ids, so we need to
* preserve the original ids because they're what downstream operators are expecting.
*/
-case class ReusedExchange(override val output: Seq[Attribute], child: Exchange) extends LeafNode {
+case class ReusedExchangeExec(override val output: Seq[Attribute], child: Exchange)
+ extends LeafExecNode {
override def sameResult(plan: SparkPlan): Boolean = {
// Ignore this wrapper. `plan` could also be a ReusedExchange, so we reverse the order here.
@@ -86,7 +87,7 @@ case class ReuseExchange(conf: SQLConf) extends Rule[SparkPlan] {
if (samePlan.isDefined) {
// Keep the output of this exchange, the following plans require that to resolve
// attributes.
- ReusedExchange(exchange.output, samePlan.get)
+ ReusedExchangeExec(exchange.output, samePlan.get)
} else {
sameSchema += exchange
exchange
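A short aside on the substitution above (a sketch, not the rule's exact code): when ReuseExchange finds an exchange equivalent to one already planned, the replacement keeps the duplicate's own output attribute ids, since downstream operators were resolved against them, while its child points at the previously planned exchange.

    import org.apache.spark.sql.execution.exchange.{Exchange, ReusedExchangeExec}

    // Sketch with assumed parameter names: reuse `first` in place of `duplicate`,
    // preserving `duplicate`'s output attributes for downstream resolution.
    def reuse(duplicate: Exchange, first: Exchange): ReusedExchangeExec =
      ReusedExchangeExec(duplicate.output, first)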
http://git-wip-us.apache.org/repos/asf/spark/blob/d7d0cad0/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
deleted file mode 100644
index 89487c6..0000000
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
+++ /dev/null
@@ -1,401 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.execution.joins
-
-import org.apache.spark.TaskContext
-import org.apache.spark.broadcast.Broadcast
-import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, GenerateUnsafeProjection}
-import org.apache.spark.sql.catalyst.plans._
-import org.apache.spark.sql.catalyst.plans.physical.{BroadcastDistribution, Distribution, Partitioning, UnspecifiedDistribution}
-import org.apache.spark.sql.execution.{BinaryNode, CodegenSupport, SparkPlan}
-import org.apache.spark.sql.execution.metric.SQLMetrics
-import org.apache.spark.sql.types.LongType
-
-/**
- * Performs an inner hash join of two child relations. When the output RDD of this operator is
- * being constructed, a Spark job is asynchronously started to calculate the values for the
- * broadcast relation. This data is then placed in a Spark broadcast variable. The streamed
- * relation is not shuffled.
- */
-case class BroadcastHashJoin(
- leftKeys: Seq[Expression],
- rightKeys: Seq[Expression],
- joinType: JoinType,
- buildSide: BuildSide,
- condition: Option[Expression],
- left: SparkPlan,
- right: SparkPlan)
- extends BinaryNode with HashJoin with CodegenSupport {
-
- override private[sql] lazy val metrics = Map(
- "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
-
- override def outputPartitioning: Partitioning = streamedPlan.outputPartitioning
-
- override def requiredChildDistribution: Seq[Distribution] = {
- val mode = HashedRelationBroadcastMode(buildKeys)
- buildSide match {
- case BuildLeft =>
- BroadcastDistribution(mode) :: UnspecifiedDistribution :: Nil
- case BuildRight =>
- UnspecifiedDistribution :: BroadcastDistribution(mode) :: Nil
- }
- }
-
- protected override def doExecute(): RDD[InternalRow] = {
- val numOutputRows = longMetric("numOutputRows")
-
- val broadcastRelation = buildPlan.executeBroadcast[HashedRelation]()
- streamedPlan.execute().mapPartitions { streamedIter =>
- val hashed = broadcastRelation.value.asReadOnlyCopy()
- TaskContext.get().taskMetrics().incPeakExecutionMemory(hashed.estimatedSize)
- join(streamedIter, hashed, numOutputRows)
- }
- }
-
- override def inputRDDs(): Seq[RDD[InternalRow]] = {
- streamedPlan.asInstanceOf[CodegenSupport].inputRDDs()
- }
-
- override def doProduce(ctx: CodegenContext): String = {
- streamedPlan.asInstanceOf[CodegenSupport].produce(ctx, this)
- }
-
- override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
- joinType match {
- case Inner => codegenInner(ctx, input)
- case LeftOuter | RightOuter => codegenOuter(ctx, input)
- case LeftSemi => codegenSemi(ctx, input)
- case LeftAnti => codegenAnti(ctx, input)
- case x =>
- throw new IllegalArgumentException(
- s"BroadcastHashJoin should not take $x as the JoinType")
- }
- }
-
- /**
- * Returns a tuple of Broadcast of HashedRelation and the variable name for it.
- */
- private def prepareBroadcast(ctx: CodegenContext): (Broadcast[HashedRelation], String) = {
- // create a name for HashedRelation
- val broadcastRelation = buildPlan.executeBroadcast[HashedRelation]()
- val broadcast = ctx.addReferenceObj("broadcast", broadcastRelation)
- val relationTerm = ctx.freshName("relation")
- val clsName = broadcastRelation.value.getClass.getName
- ctx.addMutableState(clsName, relationTerm,
- s"""
- | $relationTerm = (($clsName) $broadcast.value()).asReadOnlyCopy();
- | incPeakExecutionMemory($relationTerm.estimatedSize());
- """.stripMargin)
- (broadcastRelation, relationTerm)
- }
-
- /**
- * Returns the code for generating join key for stream side, and expression of whether the key
- * has any null in it or not.
- */
- private def genStreamSideJoinKey(
- ctx: CodegenContext,
- input: Seq[ExprCode]): (ExprCode, String) = {
- ctx.currentVars = input
- if (streamedKeys.length == 1 && streamedKeys.head.dataType == LongType) {
- // generate the join key as Long
- val ev = streamedKeys.head.genCode(ctx)
- (ev, ev.isNull)
- } else {
- // generate the join key as UnsafeRow
- val ev = GenerateUnsafeProjection.createCode(ctx, streamedKeys)
- (ev, s"${ev.value}.anyNull()")
- }
- }
-
- /**
- * Generates the code for variable of build side.
- */
- private def genBuildSideVars(ctx: CodegenContext, matched: String): Seq[ExprCode] = {
- ctx.currentVars = null
- ctx.INPUT_ROW = matched
- buildPlan.output.zipWithIndex.map { case (a, i) =>
- val ev = BoundReference(i, a.dataType, a.nullable).genCode(ctx)
- if (joinType == Inner) {
- ev
- } else {
- // the variables are needed even there is no matched rows
- val isNull = ctx.freshName("isNull")
- val value = ctx.freshName("value")
- val code = s"""
- |boolean $isNull = true;
- |${ctx.javaType(a.dataType)} $value = ${ctx.defaultValue(a.dataType)};
- |if ($matched != null) {
- | ${ev.code}
- | $isNull = ${ev.isNull};
- | $value = ${ev.value};
- |}
- """.stripMargin
- ExprCode(code, isNull, value)
- }
- }
- }
-
- /**
- * Generate the (non-equi) condition used to filter joined rows. This is used in Inner, Left Semi
- * and Left Anti joins.
- */
- private def getJoinCondition(
- ctx: CodegenContext,
- input: Seq[ExprCode]): (String, String, Seq[ExprCode]) = {
- val matched = ctx.freshName("matched")
- val buildVars = genBuildSideVars(ctx, matched)
- val checkCondition = if (condition.isDefined) {
- val expr = condition.get
- // evaluate the variables from build side that used by condition
- val eval = evaluateRequiredVariables(buildPlan.output, buildVars, expr.references)
- // filter the output via condition
- ctx.currentVars = input ++ buildVars
- val ev =
- BindReferences.bindReference(expr, streamedPlan.output ++ buildPlan.output).genCode(ctx)
- s"""
- |$eval
- |${ev.code}
- |if (${ev.isNull} || !${ev.value}) continue;
- """.stripMargin
- } else {
- ""
- }
- (matched, checkCondition, buildVars)
- }
-
- /**
- * Generates the code for Inner join.
- */
- private def codegenInner(ctx: CodegenContext, input: Seq[ExprCode]): String = {
- val (broadcastRelation, relationTerm) = prepareBroadcast(ctx)
- val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input)
- val (matched, checkCondition, buildVars) = getJoinCondition(ctx, input)
- val numOutput = metricTerm(ctx, "numOutputRows")
-
- val resultVars = buildSide match {
- case BuildLeft => buildVars ++ input
- case BuildRight => input ++ buildVars
- }
- if (broadcastRelation.value.keyIsUnique) {
- s"""
- |// generate join key for stream side
- |${keyEv.code}
- |// find matches from HashedRelation
- |UnsafeRow $matched = $anyNull ? null: (UnsafeRow)$relationTerm.getValue(${keyEv.value});
- |if ($matched == null) continue;
- |$checkCondition
- |$numOutput.add(1);
- |${consume(ctx, resultVars)}
- """.stripMargin
-
- } else {
- ctx.copyResult = true
- val matches = ctx.freshName("matches")
- val iteratorCls = classOf[Iterator[UnsafeRow]].getName
- s"""
- |// generate join key for stream side
- |${keyEv.code}
- |// find matches from HashRelation
- |$iteratorCls $matches = $anyNull ? null : ($iteratorCls)$relationTerm.get(${keyEv.value});
- |if ($matches == null) continue;
- |while ($matches.hasNext()) {
- | UnsafeRow $matched = (UnsafeRow) $matches.next();
- | $checkCondition
- | $numOutput.add(1);
- | ${consume(ctx, resultVars)}
- |}
- """.stripMargin
- }
- }
-
- /**
- * Generates the code for left or right outer join.
- */
- private def codegenOuter(ctx: CodegenContext, input: Seq[ExprCode]): String = {
- val (broadcastRelation, relationTerm) = prepareBroadcast(ctx)
- val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input)
- val matched = ctx.freshName("matched")
- val buildVars = genBuildSideVars(ctx, matched)
- val numOutput = metricTerm(ctx, "numOutputRows")
-
- // filter the output via condition
- val conditionPassed = ctx.freshName("conditionPassed")
- val checkCondition = if (condition.isDefined) {
- val expr = condition.get
- // evaluate the variables from build side that used by condition
- val eval = evaluateRequiredVariables(buildPlan.output, buildVars, expr.references)
- ctx.currentVars = input ++ buildVars
- val ev =
- BindReferences.bindReference(expr, streamedPlan.output ++ buildPlan.output).genCode(ctx)
- s"""
- |boolean $conditionPassed = true;
- |${eval.trim}
- |${ev.code}
- |if ($matched != null) {
- | $conditionPassed = !${ev.isNull} && ${ev.value};
- |}
- """.stripMargin
- } else {
- s"final boolean $conditionPassed = true;"
- }
-
- val resultVars = buildSide match {
- case BuildLeft => buildVars ++ input
- case BuildRight => input ++ buildVars
- }
- if (broadcastRelation.value.keyIsUnique) {
- s"""
- |// generate join key for stream side
- |${keyEv.code}
- |// find matches from HashedRelation
- |UnsafeRow $matched = $anyNull ? null: (UnsafeRow)$relationTerm.getValue(${keyEv.value});
- |${checkCondition.trim}
- |if (!$conditionPassed) {
- | $matched = null;
- | // reset the variables those are already evaluated.
- | ${buildVars.filter(_.code == "").map(v => s"${v.isNull} = true;").mkString("\n")}
- |}
- |$numOutput.add(1);
- |${consume(ctx, resultVars)}
- """.stripMargin
-
- } else {
- ctx.copyResult = true
- val matches = ctx.freshName("matches")
- val iteratorCls = classOf[Iterator[UnsafeRow]].getName
- val found = ctx.freshName("found")
- s"""
- |// generate join key for stream side
- |${keyEv.code}
- |// find matches from HashRelation
- |$iteratorCls $matches = $anyNull ? null : ($iteratorCls)$relationTerm.get(${keyEv.value});
- |boolean $found = false;
- |// the last iteration of this loop is to emit an empty row if there is no matched rows.
- |while ($matches != null && $matches.hasNext() || !$found) {
- | UnsafeRow $matched = $matches != null && $matches.hasNext() ?
- | (UnsafeRow) $matches.next() : null;
- | ${checkCondition.trim}
- | if (!$conditionPassed) continue;
- | $found = true;
- | $numOutput.add(1);
- | ${consume(ctx, resultVars)}
- |}
- """.stripMargin
- }
- }
-
- /**
- * Generates the code for left semi join.
- */
- private def codegenSemi(ctx: CodegenContext, input: Seq[ExprCode]): String = {
- val (broadcastRelation, relationTerm) = prepareBroadcast(ctx)
- val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input)
- val (matched, checkCondition, _) = getJoinCondition(ctx, input)
- val numOutput = metricTerm(ctx, "numOutputRows")
- if (broadcastRelation.value.keyIsUnique) {
- s"""
- |// generate join key for stream side
- |${keyEv.code}
- |// find matches from HashedRelation
- |UnsafeRow $matched = $anyNull ? null: (UnsafeRow)$relationTerm.getValue(${keyEv.value});
- |if ($matched == null) continue;
- |$checkCondition
- |$numOutput.add(1);
- |${consume(ctx, input)}
- """.stripMargin
- } else {
- val matches = ctx.freshName("matches")
- val iteratorCls = classOf[Iterator[UnsafeRow]].getName
- val found = ctx.freshName("found")
- s"""
- |// generate join key for stream side
- |${keyEv.code}
- |// find matches from HashRelation
- |$iteratorCls $matches = $anyNull ? null : ($iteratorCls)$relationTerm.get(${keyEv.value});
- |if ($matches == null) continue;
- |boolean $found = false;
- |while (!$found && $matches.hasNext()) {
- | UnsafeRow $matched = (UnsafeRow) $matches.next();
- | $checkCondition
- | $found = true;
- |}
- |if (!$found) continue;
- |$numOutput.add(1);
- |${consume(ctx, input)}
- """.stripMargin
- }
- }
-
- /**
- * Generates the code for anti join.
- */
- private def codegenAnti(ctx: CodegenContext, input: Seq[ExprCode]): String = {
- val (broadcastRelation, relationTerm) = prepareBroadcast(ctx)
- val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input)
- val (matched, checkCondition, _) = getJoinCondition(ctx, input)
- val numOutput = metricTerm(ctx, "numOutputRows")
-
- if (broadcastRelation.value.keyIsUnique) {
- s"""
- |// generate join key for stream side
- |${keyEv.code}
- |// Check if the key has nulls.
- |if (!($anyNull)) {
- | // Check if the HashedRelation exists.
- | UnsafeRow $matched = (UnsafeRow)$relationTerm.getValue(${keyEv.value});
- | if ($matched != null) {
- | // Evaluate the condition.
- | $checkCondition
- | }
- |}
- |$numOutput.add(1);
- |${consume(ctx, input)}
- """.stripMargin
- } else {
- val matches = ctx.freshName("matches")
- val iteratorCls = classOf[Iterator[UnsafeRow]].getName
- val found = ctx.freshName("found")
- s"""
- |// generate join key for stream side
- |${keyEv.code}
- |// Check if the key has nulls.
- |if (!($anyNull)) {
- | // Check if the HashedRelation exists.
- | $iteratorCls $matches = ($iteratorCls)$relationTerm.get(${keyEv.value});
- | if ($matches != null) {
- | // Evaluate the condition.
- | boolean $found = false;
- | while (!$found && $matches.hasNext()) {
- | UnsafeRow $matched = (UnsafeRow) $matches.next();
- | $checkCondition
- | $found = true;
- | }
- | if ($found) continue;
- | }
- |}
- |$numOutput.add(1);
- |${consume(ctx, input)}
- """.stripMargin
- }
- }
-}
http://git-wip-us.apache.org/repos/asf/spark/blob/d7d0cad0/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
new file mode 100644
index 0000000..51399e1
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
@@ -0,0 +1,401 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.joins
+
+import org.apache.spark.TaskContext
+import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, GenerateUnsafeProjection}
+import org.apache.spark.sql.catalyst.plans._
+import org.apache.spark.sql.catalyst.plans.physical.{BroadcastDistribution, Distribution, Partitioning, UnspecifiedDistribution}
+import org.apache.spark.sql.execution.{BinaryExecNode, CodegenSupport, SparkPlan}
+import org.apache.spark.sql.execution.metric.SQLMetrics
+import org.apache.spark.sql.types.LongType
+
+/**
+ * Performs an inner hash join of two child relations. When the output RDD of this operator is
+ * being constructed, a Spark job is asynchronously started to calculate the values for the
+ * broadcast relation. This data is then placed in a Spark broadcast variable. The streamed
+ * relation is not shuffled.
+ */
+case class BroadcastHashJoinExec(
+ leftKeys: Seq[Expression],
+ rightKeys: Seq[Expression],
+ joinType: JoinType,
+ buildSide: BuildSide,
+ condition: Option[Expression],
+ left: SparkPlan,
+ right: SparkPlan)
+ extends BinaryExecNode with HashJoin with CodegenSupport {
+
+ override private[sql] lazy val metrics = Map(
+ "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
+
+ override def outputPartitioning: Partitioning = streamedPlan.outputPartitioning
+
+ override def requiredChildDistribution: Seq[Distribution] = {
+ val mode = HashedRelationBroadcastMode(buildKeys)
+ buildSide match {
+ case BuildLeft =>
+ BroadcastDistribution(mode) :: UnspecifiedDistribution :: Nil
+ case BuildRight =>
+ UnspecifiedDistribution :: BroadcastDistribution(mode) :: Nil
+ }
+ }
+
+ protected override def doExecute(): RDD[InternalRow] = {
+ val numOutputRows = longMetric("numOutputRows")
+
+ val broadcastRelation = buildPlan.executeBroadcast[HashedRelation]()
+ streamedPlan.execute().mapPartitions { streamedIter =>
+ val hashed = broadcastRelation.value.asReadOnlyCopy()
+ TaskContext.get().taskMetrics().incPeakExecutionMemory(hashed.estimatedSize)
+ join(streamedIter, hashed, numOutputRows)
+ }
+ }
+
+ override def inputRDDs(): Seq[RDD[InternalRow]] = {
+ streamedPlan.asInstanceOf[CodegenSupport].inputRDDs()
+ }
+
+ override def doProduce(ctx: CodegenContext): String = {
+ streamedPlan.asInstanceOf[CodegenSupport].produce(ctx, this)
+ }
+
+ override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
+ joinType match {
+ case Inner => codegenInner(ctx, input)
+ case LeftOuter | RightOuter => codegenOuter(ctx, input)
+ case LeftSemi => codegenSemi(ctx, input)
+ case LeftAnti => codegenAnti(ctx, input)
+ case x =>
+ throw new IllegalArgumentException(
+ s"BroadcastHashJoin should not take $x as the JoinType")
+ }
+ }
+
+ /**
+ * Returns a tuple of Broadcast of HashedRelation and the variable name for it.
+ */
+ private def prepareBroadcast(ctx: CodegenContext): (Broadcast[HashedRelation], String) = {
+ // create a name for HashedRelation
+ val broadcastRelation = buildPlan.executeBroadcast[HashedRelation]()
+ val broadcast = ctx.addReferenceObj("broadcast", broadcastRelation)
+ val relationTerm = ctx.freshName("relation")
+ val clsName = broadcastRelation.value.getClass.getName
+ ctx.addMutableState(clsName, relationTerm,
+ s"""
+ | $relationTerm = (($clsName) $broadcast.value()).asReadOnlyCopy();
+ | incPeakExecutionMemory($relationTerm.estimatedSize());
+ """.stripMargin)
+ (broadcastRelation, relationTerm)
+ }
+
+ /**
+ * Returns the code for generating join key for stream side, and expression of whether the key
+ * has any null in it or not.
+ */
+ private def genStreamSideJoinKey(
+ ctx: CodegenContext,
+ input: Seq[ExprCode]): (ExprCode, String) = {
+ ctx.currentVars = input
+ if (streamedKeys.length == 1 && streamedKeys.head.dataType == LongType) {
+ // generate the join key as Long
+ val ev = streamedKeys.head.genCode(ctx)
+ (ev, ev.isNull)
+ } else {
+ // generate the join key as UnsafeRow
+ val ev = GenerateUnsafeProjection.createCode(ctx, streamedKeys)
+ (ev, s"${ev.value}.anyNull()")
+ }
+ }
+
+ /**
+ * Generates the code for variable of build side.
+ */
+ private def genBuildSideVars(ctx: CodegenContext, matched: String): Seq[ExprCode] = {
+ ctx.currentVars = null
+ ctx.INPUT_ROW = matched
+ buildPlan.output.zipWithIndex.map { case (a, i) =>
+ val ev = BoundReference(i, a.dataType, a.nullable).genCode(ctx)
+ if (joinType == Inner) {
+ ev
+ } else {
+ // the variables are needed even if there are no matched rows
+ val isNull = ctx.freshName("isNull")
+ val value = ctx.freshName("value")
+ val code = s"""
+ |boolean $isNull = true;
+ |${ctx.javaType(a.dataType)} $value = ${ctx.defaultValue(a.dataType)};
+ |if ($matched != null) {
+ | ${ev.code}
+ | $isNull = ${ev.isNull};
+ | $value = ${ev.value};
+ |}
+ """.stripMargin
+ ExprCode(code, isNull, value)
+ }
+ }
+ }
+
+ /**
+ * Generate the (non-equi) condition used to filter joined rows. This is used in Inner, Left Semi
+ * and Left Anti joins.
+ */
+ private def getJoinCondition(
+ ctx: CodegenContext,
+ input: Seq[ExprCode]): (String, String, Seq[ExprCode]) = {
+ val matched = ctx.freshName("matched")
+ val buildVars = genBuildSideVars(ctx, matched)
+ val checkCondition = if (condition.isDefined) {
+ val expr = condition.get
+ // evaluate the variables from the build side that are used by the condition
+ val eval = evaluateRequiredVariables(buildPlan.output, buildVars, expr.references)
+ // filter the output via condition
+ ctx.currentVars = input ++ buildVars
+ val ev =
+ BindReferences.bindReference(expr, streamedPlan.output ++ buildPlan.output).genCode(ctx)
+ s"""
+ |$eval
+ |${ev.code}
+ |if (${ev.isNull} || !${ev.value}) continue;
+ """.stripMargin
+ } else {
+ ""
+ }
+ (matched, checkCondition, buildVars)
+ }
+
+ /**
+ * Generates the code for Inner join.
+ */
+ private def codegenInner(ctx: CodegenContext, input: Seq[ExprCode]): String = {
+ val (broadcastRelation, relationTerm) = prepareBroadcast(ctx)
+ val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input)
+ val (matched, checkCondition, buildVars) = getJoinCondition(ctx, input)
+ val numOutput = metricTerm(ctx, "numOutputRows")
+
+ val resultVars = buildSide match {
+ case BuildLeft => buildVars ++ input
+ case BuildRight => input ++ buildVars
+ }
+ if (broadcastRelation.value.keyIsUnique) {
+ s"""
+ |// generate join key for stream side
+ |${keyEv.code}
+ |// find matches from HashedRelation
+ |UnsafeRow $matched = $anyNull ? null: (UnsafeRow)$relationTerm.getValue(${keyEv.value});
+ |if ($matched == null) continue;
+ |$checkCondition
+ |$numOutput.add(1);
+ |${consume(ctx, resultVars)}
+ """.stripMargin
+
+ } else {
+ ctx.copyResult = true
+ val matches = ctx.freshName("matches")
+ val iteratorCls = classOf[Iterator[UnsafeRow]].getName
+ s"""
+ |// generate join key for stream side
+ |${keyEv.code}
+ |// find matches from HashRelation
+ |$iteratorCls $matches = $anyNull ? null : ($iteratorCls)$relationTerm.get(${keyEv.value});
+ |if ($matches == null) continue;
+ |while ($matches.hasNext()) {
+ | UnsafeRow $matched = (UnsafeRow) $matches.next();
+ | $checkCondition
+ | $numOutput.add(1);
+ | ${consume(ctx, resultVars)}
+ |}
+ """.stripMargin
+ }
+ }
+
+ /**
+ * Generates the code for left or right outer join.
+ */
+ private def codegenOuter(ctx: CodegenContext, input: Seq[ExprCode]): String = {
+ val (broadcastRelation, relationTerm) = prepareBroadcast(ctx)
+ val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input)
+ val matched = ctx.freshName("matched")
+ val buildVars = genBuildSideVars(ctx, matched)
+ val numOutput = metricTerm(ctx, "numOutputRows")
+
+ // filter the output via condition
+ val conditionPassed = ctx.freshName("conditionPassed")
+ val checkCondition = if (condition.isDefined) {
+ val expr = condition.get
+ // evaluate the variables from the build side that are used by the condition
+ val eval = evaluateRequiredVariables(buildPlan.output, buildVars, expr.references)
+ ctx.currentVars = input ++ buildVars
+ val ev =
+ BindReferences.bindReference(expr, streamedPlan.output ++ buildPlan.output).genCode(ctx)
+ s"""
+ |boolean $conditionPassed = true;
+ |${eval.trim}
+ |${ev.code}
+ |if ($matched != null) {
+ | $conditionPassed = !${ev.isNull} && ${ev.value};
+ |}
+ """.stripMargin
+ } else {
+ s"final boolean $conditionPassed = true;"
+ }
+
+ val resultVars = buildSide match {
+ case BuildLeft => buildVars ++ input
+ case BuildRight => input ++ buildVars
+ }
+ if (broadcastRelation.value.keyIsUnique) {
+ s"""
+ |// generate join key for stream side
+ |${keyEv.code}
+ |// find matches from HashedRelation
+ |UnsafeRow $matched = $anyNull ? null: (UnsafeRow)$relationTerm.getValue(${keyEv.value});
+ |${checkCondition.trim}
+ |if (!$conditionPassed) {
+ | $matched = null;
+ | // reset the variables that are already evaluated.
+ | ${buildVars.filter(_.code == "").map(v => s"${v.isNull} = true;").mkString("\n")}
+ |}
+ |$numOutput.add(1);
+ |${consume(ctx, resultVars)}
+ """.stripMargin
+
+ } else {
+ ctx.copyResult = true
+ val matches = ctx.freshName("matches")
+ val iteratorCls = classOf[Iterator[UnsafeRow]].getName
+ val found = ctx.freshName("found")
+ s"""
+ |// generate join key for stream side
+ |${keyEv.code}
+ |// find matches from HashRelation
+ |$iteratorCls $matches = $anyNull ? null : ($iteratorCls)$relationTerm.get(${keyEv.value});
+ |boolean $found = false;
+ |// the last iteration of this loop is to emit an empty row if there are no matched rows.
+ |while ($matches != null && $matches.hasNext() || !$found) {
+ | UnsafeRow $matched = $matches != null && $matches.hasNext() ?
+ | (UnsafeRow) $matches.next() : null;
+ | ${checkCondition.trim}
+ | if (!$conditionPassed) continue;
+ | $found = true;
+ | $numOutput.add(1);
+ | ${consume(ctx, resultVars)}
+ |}
+ """.stripMargin
+ }
+ }
+
+ /**
+ * Generates the code for left semi join.
+ */
+ private def codegenSemi(ctx: CodegenContext, input: Seq[ExprCode]): String = {
+ val (broadcastRelation, relationTerm) = prepareBroadcast(ctx)
+ val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input)
+ val (matched, checkCondition, _) = getJoinCondition(ctx, input)
+ val numOutput = metricTerm(ctx, "numOutputRows")
+ if (broadcastRelation.value.keyIsUnique) {
+ s"""
+ |// generate join key for stream side
+ |${keyEv.code}
+ |// find matches from HashedRelation
+ |UnsafeRow $matched = $anyNull ? null: (UnsafeRow)$relationTerm.getValue(${keyEv.value});
+ |if ($matched == null) continue;
+ |$checkCondition
+ |$numOutput.add(1);
+ |${consume(ctx, input)}
+ """.stripMargin
+ } else {
+ val matches = ctx.freshName("matches")
+ val iteratorCls = classOf[Iterator[UnsafeRow]].getName
+ val found = ctx.freshName("found")
+ s"""
+ |// generate join key for stream side
+ |${keyEv.code}
+ |// find matches from HashRelation
+ |$iteratorCls $matches = $anyNull ? null : ($iteratorCls)$relationTerm.get(${keyEv.value});
+ |if ($matches == null) continue;
+ |boolean $found = false;
+ |while (!$found && $matches.hasNext()) {
+ | UnsafeRow $matched = (UnsafeRow) $matches.next();
+ | $checkCondition
+ | $found = true;
+ |}
+ |if (!$found) continue;
+ |$numOutput.add(1);
+ |${consume(ctx, input)}
+ """.stripMargin
+ }
+ }
+
+ /**
+ * Generates the code for anti join.
+ */
+ private def codegenAnti(ctx: CodegenContext, input: Seq[ExprCode]): String = {
+ val (broadcastRelation, relationTerm) = prepareBroadcast(ctx)
+ val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input)
+ val (matched, checkCondition, _) = getJoinCondition(ctx, input)
+ val numOutput = metricTerm(ctx, "numOutputRows")
+
+ if (broadcastRelation.value.keyIsUnique) {
+ s"""
+ |// generate join key for stream side
+ |${keyEv.code}
+ |// Check if the key has nulls.
+ |if (!($anyNull)) {
+ | // Check if the HashedRelation exists.
+ | UnsafeRow $matched = (UnsafeRow)$relationTerm.getValue(${keyEv.value});
+ | if ($matched != null) {
+ | // Evaluate the condition.
+ | $checkCondition
+ | }
+ |}
+ |$numOutput.add(1);
+ |${consume(ctx, input)}
+ """.stripMargin
+ } else {
+ val matches = ctx.freshName("matches")
+ val iteratorCls = classOf[Iterator[UnsafeRow]].getName
+ val found = ctx.freshName("found")
+ s"""
+ |// generate join key for stream side
+ |${keyEv.code}
+ |// Check if the key has nulls.
+ |if (!($anyNull)) {
+ | // Check if the HashedRelation exists.
+ | $iteratorCls $matches = ($iteratorCls)$relationTerm.get(${keyEv.value});
+ | if ($matches != null) {
+ | // Evaluate the condition.
+ | boolean $found = false;
+ | while (!$found && $matches.hasNext()) {
+ | UnsafeRow $matched = (UnsafeRow) $matches.next();
+ | $checkCondition
+ | $found = true;
+ | }
+ | if ($found) continue;
+ | }
+ |}
+ |$numOutput.add(1);
+ |${consume(ctx, input)}
+ """.stripMargin
+ }
+ }
+}
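As a usage-level illustration of when this operator appears (not part of the change; DataFrame names and sizes are arbitrary): hinting the small side with functions.broadcast typically leads the planner to pick a broadcast hash join for an equi-join. Assuming a Spark 2.0-style session named spark:

    import org.apache.spark.sql.functions.broadcast

    val largeDf = spark.range(1000000L).toDF("id")
    val smallDf = spark.range(100L).toDF("id")
    val joined  = largeDf.join(broadcast(smallDf), Seq("id"), "inner")
    joined.explain()   // the physical plan should show a broadcast hash join on the hinted side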
http://git-wip-us.apache.org/repos/asf/spark/blob/d7d0cad0/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala
deleted file mode 100644
index 4ba710c..0000000
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala
+++ /dev/null
@@ -1,331 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.execution.joins
-
-import org.apache.spark.broadcast.Broadcast
-import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans._
-import org.apache.spark.sql.catalyst.plans.physical._
-import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
-import org.apache.spark.sql.execution.metric.SQLMetrics
-import org.apache.spark.util.collection.{BitSet, CompactBuffer}
-
-case class BroadcastNestedLoopJoin(
- left: SparkPlan,
- right: SparkPlan,
- buildSide: BuildSide,
- joinType: JoinType,
- condition: Option[Expression]) extends BinaryNode {
-
- override private[sql] lazy val metrics = Map(
- "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
-
- /** BuildRight means the right relation <=> the broadcast relation. */
- private val (streamed, broadcast) = buildSide match {
- case BuildRight => (left, right)
- case BuildLeft => (right, left)
- }
-
- override def requiredChildDistribution: Seq[Distribution] = buildSide match {
- case BuildLeft =>
- BroadcastDistribution(IdentityBroadcastMode) :: UnspecifiedDistribution :: Nil
- case BuildRight =>
- UnspecifiedDistribution :: BroadcastDistribution(IdentityBroadcastMode) :: Nil
- }
-
- private[this] def genResultProjection: InternalRow => InternalRow = {
- if (joinType == LeftSemi) {
- UnsafeProjection.create(output, output)
- } else {
- // Always put the stream side on left to simplify implementation
- // both of left and right side could be null
- UnsafeProjection.create(
- output, (streamed.output ++ broadcast.output).map(_.withNullability(true)))
- }
- }
-
- override def outputPartitioning: Partitioning = streamed.outputPartitioning
-
- override def output: Seq[Attribute] = {
- joinType match {
- case Inner =>
- left.output ++ right.output
- case LeftOuter =>
- left.output ++ right.output.map(_.withNullability(true))
- case RightOuter =>
- left.output.map(_.withNullability(true)) ++ right.output
- case FullOuter =>
- left.output.map(_.withNullability(true)) ++ right.output.map(_.withNullability(true))
- case LeftExistence(_) =>
- left.output
- case x =>
- throw new IllegalArgumentException(
- s"BroadcastNestedLoopJoin should not take $x as the JoinType")
- }
- }
-
- @transient private lazy val boundCondition = {
- if (condition.isDefined) {
- newPredicate(condition.get, streamed.output ++ broadcast.output)
- } else {
- (r: InternalRow) => true
- }
- }
-
- /**
- * The implementation for InnerJoin.
- */
- private def innerJoin(relation: Broadcast[Array[InternalRow]]): RDD[InternalRow] = {
- streamed.execute().mapPartitionsInternal { streamedIter =>
- val buildRows = relation.value
- val joinedRow = new JoinedRow
-
- streamedIter.flatMap { streamedRow =>
- val joinedRows = buildRows.iterator.map(r => joinedRow(streamedRow, r))
- if (condition.isDefined) {
- joinedRows.filter(boundCondition)
- } else {
- joinedRows
- }
- }
- }
- }
-
- /**
- * The implementation for these joins:
- *
- * LeftOuter with BuildRight
- * RightOuter with BuildLeft
- */
- private def outerJoin(relation: Broadcast[Array[InternalRow]]): RDD[InternalRow] = {
- streamed.execute().mapPartitionsInternal { streamedIter =>
- val buildRows = relation.value
- val joinedRow = new JoinedRow
- val nulls = new GenericMutableRow(broadcast.output.size)
-
- // Returns an iterator to avoid copy the rows.
- new Iterator[InternalRow] {
- // current row from stream side
- private var streamRow: InternalRow = null
- // have found a match for current row or not
- private var foundMatch: Boolean = false
- // the matched result row
- private var resultRow: InternalRow = null
- // the next index of buildRows to try
- private var nextIndex: Int = 0
-
- private def findNextMatch(): Boolean = {
- if (streamRow == null) {
- if (!streamedIter.hasNext) {
- return false
- }
- streamRow = streamedIter.next()
- nextIndex = 0
- foundMatch = false
- }
- while (nextIndex < buildRows.length) {
- resultRow = joinedRow(streamRow, buildRows(nextIndex))
- nextIndex += 1
- if (boundCondition(resultRow)) {
- foundMatch = true
- return true
- }
- }
- if (!foundMatch) {
- resultRow = joinedRow(streamRow, nulls)
- streamRow = null
- true
- } else {
- resultRow = null
- streamRow = null
- findNextMatch()
- }
- }
-
- override def hasNext(): Boolean = {
- resultRow != null || findNextMatch()
- }
- override def next(): InternalRow = {
- val r = resultRow
- resultRow = null
- r
- }
- }
- }
- }
-
- /**
- * The implementation for these joins:
- *
- * LeftSemi with BuildRight
- * Anti with BuildRight
- */
- private def leftExistenceJoin(
- relation: Broadcast[Array[InternalRow]],
- exists: Boolean): RDD[InternalRow] = {
- assert(buildSide == BuildRight)
- streamed.execute().mapPartitionsInternal { streamedIter =>
- val buildRows = relation.value
- val joinedRow = new JoinedRow
-
- if (condition.isDefined) {
- streamedIter.filter(l =>
- buildRows.exists(r => boundCondition(joinedRow(l, r))) == exists
- )
- } else if (buildRows.nonEmpty == exists) {
- streamedIter
- } else {
- Iterator.empty
- }
- }
- }
-
- /**
- * The implementation for these joins:
- *
- * LeftOuter with BuildLeft
- * RightOuter with BuildRight
- * FullOuter
- * LeftSemi with BuildLeft
- * Anti with BuildLeft
- */
- private def defaultJoin(relation: Broadcast[Array[InternalRow]]): RDD[InternalRow] = {
- /** All rows that either match both ways, or rows from the streamed side joined with nulls. */
- val streamRdd = streamed.execute()
-
- val matchedBuildRows = streamRdd.mapPartitionsInternal { streamedIter =>
- val buildRows = relation.value
- val matched = new BitSet(buildRows.length)
- val joinedRow = new JoinedRow
-
- streamedIter.foreach { streamedRow =>
- var i = 0
- while (i < buildRows.length) {
- if (boundCondition(joinedRow(streamedRow, buildRows(i)))) {
- matched.set(i)
- }
- i += 1
- }
- }
- Seq(matched).toIterator
- }
-
- val matchedBroadcastRows = matchedBuildRows.fold(
- new BitSet(relation.value.length)
- )(_ | _)
-
- if (joinType == LeftSemi) {
- assert(buildSide == BuildLeft)
- val buf: CompactBuffer[InternalRow] = new CompactBuffer()
- var i = 0
- val rel = relation.value
- while (i < rel.length) {
- if (matchedBroadcastRows.get(i)) {
- buf += rel(i).copy()
- }
- i += 1
- }
- return sparkContext.makeRDD(buf)
- }
-
- val notMatchedBroadcastRows: Seq[InternalRow] = {
- val nulls = new GenericMutableRow(streamed.output.size)
- val buf: CompactBuffer[InternalRow] = new CompactBuffer()
- var i = 0
- val buildRows = relation.value
- val joinedRow = new JoinedRow
- joinedRow.withLeft(nulls)
- while (i < buildRows.length) {
- if (!matchedBroadcastRows.get(i)) {
- buf += joinedRow.withRight(buildRows(i)).copy()
- }
- i += 1
- }
- buf
- }
-
- if (joinType == LeftAnti) {
- return sparkContext.makeRDD(notMatchedBroadcastRows)
- }
-
- val matchedStreamRows = streamRdd.mapPartitionsInternal { streamedIter =>
- val buildRows = relation.value
- val joinedRow = new JoinedRow
- val nulls = new GenericMutableRow(broadcast.output.size)
-
- streamedIter.flatMap { streamedRow =>
- var i = 0
- var foundMatch = false
- val matchedRows = new CompactBuffer[InternalRow]
-
- while (i < buildRows.length) {
- if (boundCondition(joinedRow(streamedRow, buildRows(i)))) {
- matchedRows += joinedRow.copy()
- foundMatch = true
- }
- i += 1
- }
-
- if (!foundMatch && joinType == FullOuter) {
- matchedRows += joinedRow(streamedRow, nulls).copy()
- }
- matchedRows.iterator
- }
- }
-
- sparkContext.union(
- matchedStreamRows,
- sparkContext.makeRDD(notMatchedBroadcastRows)
- )
- }
-
- protected override def doExecute(): RDD[InternalRow] = {
- val broadcastedRelation = broadcast.executeBroadcast[Array[InternalRow]]()
-
- val resultRdd = (joinType, buildSide) match {
- case (Inner, _) =>
- innerJoin(broadcastedRelation)
- case (LeftOuter, BuildRight) | (RightOuter, BuildLeft) =>
- outerJoin(broadcastedRelation)
- case (LeftSemi, BuildRight) =>
- leftExistenceJoin(broadcastedRelation, exists = true)
- case (LeftAnti, BuildRight) =>
- leftExistenceJoin(broadcastedRelation, exists = false)
- case _ =>
- /**
- * LeftOuter with BuildLeft
- * RightOuter with BuildRight
- * FullOuter
- * LeftSemi with BuildLeft
- * Anti with BuildLeft
- */
- defaultJoin(broadcastedRelation)
- }
-
- val numOutputRows = longMetric("numOutputRows")
- resultRdd.mapPartitionsInternal { iter =>
- val resultProj = genResultProjection
- iter.map { r =>
- numOutputRows += 1
- resultProj(r)
- }
- }
- }
-}
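
Editor's note (not part of the patch): the outerJoin method removed above, and re-added verbatim under the new BroadcastNestedLoopJoinExec name below, walks the broadcast rows for each streamed row and emits a single null-padded row only when nothing matched. A minimal standalone sketch of that behaviour, using plain Scala collections and an Option in place of the null-padded JoinedRow (names and data are illustrative):

// A simplified flatMap-based version of what outerJoin computes; the patch hand-rolls
// an Iterator (findNextMatch) instead, to avoid materialising the matches per stream row.
object StreamedOuterJoinSketch {
  def main(args: Array[String]): Unit = {
    val streamRows = Iterator(1, 2, 3)                 // stands in for the streamed partition
    val buildRows  = Vector(2, 3, 4)                   // stands in for the broadcast relation
    def matches(s: Int, b: Int): Boolean = s == b      // stands in for boundCondition

    val joined: Iterator[(Int, Option[Int])] = streamRows.flatMap { s =>
      val hits = buildRows.collect { case b if matches(s, b) => (s, Option(b)) }
      // When nothing matched, emit the stream row once, padded with nulls (None here).
      if (hits.nonEmpty) hits.iterator else Iterator((s, None))
    }
    joined.foreach(println)   // (1,None), (2,Some(2)), (3,Some(3))
  }
}

The production code expresses the same logic as a hand-rolled state machine so that each output row is produced lazily, without buffering the per-row matches.
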
http://git-wip-us.apache.org/repos/asf/spark/blob/d7d0cad0/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala
new file mode 100644
index 0000000..51afa00
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala
@@ -0,0 +1,331 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.joins
+
+import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans._
+import org.apache.spark.sql.catalyst.plans.physical._
+import org.apache.spark.sql.execution.{BinaryExecNode, SparkPlan}
+import org.apache.spark.sql.execution.metric.SQLMetrics
+import org.apache.spark.util.collection.{BitSet, CompactBuffer}
+
+case class BroadcastNestedLoopJoinExec(
+ left: SparkPlan,
+ right: SparkPlan,
+ buildSide: BuildSide,
+ joinType: JoinType,
+ condition: Option[Expression]) extends BinaryExecNode {
+
+ override private[sql] lazy val metrics = Map(
+ "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
+
+ /** BuildRight means the right relation <=> the broadcast relation. */
+ private val (streamed, broadcast) = buildSide match {
+ case BuildRight => (left, right)
+ case BuildLeft => (right, left)
+ }
+
+ override def requiredChildDistribution: Seq[Distribution] = buildSide match {
+ case BuildLeft =>
+ BroadcastDistribution(IdentityBroadcastMode) :: UnspecifiedDistribution :: Nil
+ case BuildRight =>
+ UnspecifiedDistribution :: BroadcastDistribution(IdentityBroadcastMode) :: Nil
+ }
+
+ private[this] def genResultProjection: InternalRow => InternalRow = {
+ if (joinType == LeftSemi) {
+ UnsafeProjection.create(output, output)
+ } else {
+ // Always put the stream side on the left to simplify the implementation;
+ // both the left and right sides could be null
+ UnsafeProjection.create(
+ output, (streamed.output ++ broadcast.output).map(_.withNullability(true)))
+ }
+ }
+
+ override def outputPartitioning: Partitioning = streamed.outputPartitioning
+
+ override def output: Seq[Attribute] = {
+ joinType match {
+ case Inner =>
+ left.output ++ right.output
+ case LeftOuter =>
+ left.output ++ right.output.map(_.withNullability(true))
+ case RightOuter =>
+ left.output.map(_.withNullability(true)) ++ right.output
+ case FullOuter =>
+ left.output.map(_.withNullability(true)) ++ right.output.map(_.withNullability(true))
+ case LeftExistence(_) =>
+ left.output
+ case x =>
+ throw new IllegalArgumentException(
+ s"BroadcastNestedLoopJoin should not take $x as the JoinType")
+ }
+ }
+
+ @transient private lazy val boundCondition = {
+ if (condition.isDefined) {
+ newPredicate(condition.get, streamed.output ++ broadcast.output)
+ } else {
+ (r: InternalRow) => true
+ }
+ }
+
+ /**
+ * The implementation for InnerJoin.
+ */
+ private def innerJoin(relation: Broadcast[Array[InternalRow]]): RDD[InternalRow] = {
+ streamed.execute().mapPartitionsInternal { streamedIter =>
+ val buildRows = relation.value
+ val joinedRow = new JoinedRow
+
+ streamedIter.flatMap { streamedRow =>
+ val joinedRows = buildRows.iterator.map(r => joinedRow(streamedRow, r))
+ if (condition.isDefined) {
+ joinedRows.filter(boundCondition)
+ } else {
+ joinedRows
+ }
+ }
+ }
+ }
+
+ /**
+ * The implementation for these joins:
+ *
+ * LeftOuter with BuildRight
+ * RightOuter with BuildLeft
+ */
+ private def outerJoin(relation: Broadcast[Array[InternalRow]]): RDD[InternalRow] = {
+ streamed.execute().mapPartitionsInternal { streamedIter =>
+ val buildRows = relation.value
+ val joinedRow = new JoinedRow
+ val nulls = new GenericMutableRow(broadcast.output.size)
+
+ // Returns an iterator to avoid copying the rows.
+ new Iterator[InternalRow] {
+ // current row from stream side
+ private var streamRow: InternalRow = null
+ // have found a match for current row or not
+ private var foundMatch: Boolean = false
+ // the matched result row
+ private var resultRow: InternalRow = null
+ // the next index of buildRows to try
+ private var nextIndex: Int = 0
+
+ private def findNextMatch(): Boolean = {
+ if (streamRow == null) {
+ if (!streamedIter.hasNext) {
+ return false
+ }
+ streamRow = streamedIter.next()
+ nextIndex = 0
+ foundMatch = false
+ }
+ while (nextIndex < buildRows.length) {
+ resultRow = joinedRow(streamRow, buildRows(nextIndex))
+ nextIndex += 1
+ if (boundCondition(resultRow)) {
+ foundMatch = true
+ return true
+ }
+ }
+ if (!foundMatch) {
+ resultRow = joinedRow(streamRow, nulls)
+ streamRow = null
+ true
+ } else {
+ resultRow = null
+ streamRow = null
+ findNextMatch()
+ }
+ }
+
+ override def hasNext(): Boolean = {
+ resultRow != null || findNextMatch()
+ }
+ override def next(): InternalRow = {
+ val r = resultRow
+ resultRow = null
+ r
+ }
+ }
+ }
+ }
+
+ /**
+ * The implementation for these joins:
+ *
+ * LeftSemi with BuildRight
+ * Anti with BuildRight
+ */
+ private def leftExistenceJoin(
+ relation: Broadcast[Array[InternalRow]],
+ exists: Boolean): RDD[InternalRow] = {
+ assert(buildSide == BuildRight)
+ streamed.execute().mapPartitionsInternal { streamedIter =>
+ val buildRows = relation.value
+ val joinedRow = new JoinedRow
+
+ if (condition.isDefined) {
+ streamedIter.filter(l =>
+ buildRows.exists(r => boundCondition(joinedRow(l, r))) == exists
+ )
+ } else if (buildRows.nonEmpty == exists) {
+ streamedIter
+ } else {
+ Iterator.empty
+ }
+ }
+ }
+
+ /**
+ * The implementation for these joins:
+ *
+ * LeftOuter with BuildLeft
+ * RightOuter with BuildRight
+ * FullOuter
+ * LeftSemi with BuildLeft
+ * Anti with BuildLeft
+ */
+ private def defaultJoin(relation: Broadcast[Array[InternalRow]]): RDD[InternalRow] = {
+ /** All rows that either match both ways, or rows from the streamed side joined with nulls. */
+ val streamRdd = streamed.execute()
+
+ val matchedBuildRows = streamRdd.mapPartitionsInternal { streamedIter =>
+ val buildRows = relation.value
+ val matched = new BitSet(buildRows.length)
+ val joinedRow = new JoinedRow
+
+ streamedIter.foreach { streamedRow =>
+ var i = 0
+ while (i < buildRows.length) {
+ if (boundCondition(joinedRow(streamedRow, buildRows(i)))) {
+ matched.set(i)
+ }
+ i += 1
+ }
+ }
+ Seq(matched).toIterator
+ }
+
+ val matchedBroadcastRows = matchedBuildRows.fold(
+ new BitSet(relation.value.length)
+ )(_ | _)
+
+ if (joinType == LeftSemi) {
+ assert(buildSide == BuildLeft)
+ val buf: CompactBuffer[InternalRow] = new CompactBuffer()
+ var i = 0
+ val rel = relation.value
+ while (i < rel.length) {
+ if (matchedBroadcastRows.get(i)) {
+ buf += rel(i).copy()
+ }
+ i += 1
+ }
+ return sparkContext.makeRDD(buf)
+ }
+
+ val notMatchedBroadcastRows: Seq[InternalRow] = {
+ val nulls = new GenericMutableRow(streamed.output.size)
+ val buf: CompactBuffer[InternalRow] = new CompactBuffer()
+ var i = 0
+ val buildRows = relation.value
+ val joinedRow = new JoinedRow
+ joinedRow.withLeft(nulls)
+ while (i < buildRows.length) {
+ if (!matchedBroadcastRows.get(i)) {
+ buf += joinedRow.withRight(buildRows(i)).copy()
+ }
+ i += 1
+ }
+ buf
+ }
+
+ if (joinType == LeftAnti) {
+ return sparkContext.makeRDD(notMatchedBroadcastRows)
+ }
+
+ val matchedStreamRows = streamRdd.mapPartitionsInternal { streamedIter =>
+ val buildRows = relation.value
+ val joinedRow = new JoinedRow
+ val nulls = new GenericMutableRow(broadcast.output.size)
+
+ streamedIter.flatMap { streamedRow =>
+ var i = 0
+ var foundMatch = false
+ val matchedRows = new CompactBuffer[InternalRow]
+
+ while (i < buildRows.length) {
+ if (boundCondition(joinedRow(streamedRow, buildRows(i)))) {
+ matchedRows += joinedRow.copy()
+ foundMatch = true
+ }
+ i += 1
+ }
+
+ if (!foundMatch && joinType == FullOuter) {
+ matchedRows += joinedRow(streamedRow, nulls).copy()
+ }
+ matchedRows.iterator
+ }
+ }
+
+ sparkContext.union(
+ matchedStreamRows,
+ sparkContext.makeRDD(notMatchedBroadcastRows)
+ )
+ }
+
+ protected override def doExecute(): RDD[InternalRow] = {
+ val broadcastedRelation = broadcast.executeBroadcast[Array[InternalRow]]()
+
+ val resultRdd = (joinType, buildSide) match {
+ case (Inner, _) =>
+ innerJoin(broadcastedRelation)
+ case (LeftOuter, BuildRight) | (RightOuter, BuildLeft) =>
+ outerJoin(broadcastedRelation)
+ case (LeftSemi, BuildRight) =>
+ leftExistenceJoin(broadcastedRelation, exists = true)
+ case (LeftAnti, BuildRight) =>
+ leftExistenceJoin(broadcastedRelation, exists = false)
+ case _ =>
+ /**
+ * LeftOuter with BuildLeft
+ * RightOuter with BuildRight
+ * FullOuter
+ * LeftSemi with BuildLeft
+ * Anti with BuildLeft
+ */
+ defaultJoin(broadcastedRelation)
+ }
+
+ val numOutputRows = longMetric("numOutputRows")
+ resultRdd.mapPartitionsInternal { iter =>
+ val resultProj = genResultProjection
+ iter.map { r =>
+ numOutputRows += 1
+ resultProj(r)
+ }
+ }
+ }
+}
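
Editor's note (not part of the patch): defaultJoin above handles the build-side-outer cases by recording, per streamed partition, which broadcast rows found a match, combining those bitmasks on the driver, and then null-padding the broadcast rows that never matched. A minimal standalone sketch of that bookkeeping, using scala.collection.mutable.BitSet in place of Spark's internal BitSet and CompactBuffer (example data and names are illustrative):

import scala.collection.mutable

object NestedLoopMatchSketch {
  def main(args: Array[String]): Unit = {
    val buildRows = Vector(1, 2, 3, 4)                    // stands in for the broadcast side
    val streamPartitions = Seq(Seq(2, 5), Seq(3, 7))      // stands in for the streamed partitions

    // One bitmask per partition: bit i is set iff buildRows(i) matched some streamed row.
    val perPartition = streamPartitions.map { part =>
      val matched = new mutable.BitSet(buildRows.length)
      for (s <- part; i <- buildRows.indices if buildRows(i) == s) matched += i
      matched
    }

    // Union the per-partition bitmasks, mirroring matchedBuildRows.fold(new BitSet(...))(_ | _).
    val matchedAll = perPartition.foldLeft(new mutable.BitSet(buildRows.length))(_ | _)

    // In the patch these become notMatchedBroadcastRows: null-padded for FullOuter,
    // or the result of LeftAnti with BuildLeft.
    val notMatched = buildRows.indices.filterNot(matchedAll.contains).map(buildRows)
    println(s"matched indices: ${matchedAll.mkString(",")}; unmatched build rows: $notMatched")
  }
}

Because the per-partition bitmasks are folded back to the driver, this is the slow fallback path; the specialised innerJoin, outerJoin, and leftExistenceJoin paths avoid that round trip.
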
http://git-wip-us.apache.org/repos/asf/spark/blob/d7d0cad0/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala
deleted file mode 100644
index b1de52b..0000000
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala
+++ /dev/null
@@ -1,103 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.execution.joins
-
-import org.apache.spark._
-import org.apache.spark.rdd.{CartesianPartition, CartesianRDD, RDD}
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeRow}
-import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeRowJoiner
-import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
-import org.apache.spark.sql.execution.metric.SQLMetrics
-import org.apache.spark.util.CompletionIterator
-import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter
-
-/**
- * An optimized CartesianRDD for UnsafeRow, which caches the rows from the second child RDD.
- * This is much faster than rebuilding the right partition for every row in the left RDD, and it
- * also materializes the right RDD (in case the right RDD is nondeterministic).
- */
-private[spark]
-class UnsafeCartesianRDD(left : RDD[UnsafeRow], right : RDD[UnsafeRow], numFieldsOfRight: Int)
- extends CartesianRDD[UnsafeRow, UnsafeRow](left.sparkContext, left, right) {
-
- override def compute(split: Partition, context: TaskContext): Iterator[(UnsafeRow, UnsafeRow)] = {
- // We will not sort the rows, so prefixComparator and recordComparator are null.
- val sorter = UnsafeExternalSorter.create(
- context.taskMemoryManager(),
- SparkEnv.get.blockManager,
- SparkEnv.get.serializerManager,
- context,
- null,
- null,
- 1024,
- SparkEnv.get.memoryManager.pageSizeBytes,
- false)
-
- val partition = split.asInstanceOf[CartesianPartition]
- for (y <- rdd2.iterator(partition.s2, context)) {
- sorter.insertRecord(y.getBaseObject, y.getBaseOffset, y.getSizeInBytes, 0)
- }
-
- // Create an iterator from the sorter and wrap it as Iterator[UnsafeRow]
- def createIter(): Iterator[UnsafeRow] = {
- val iter = sorter.getIterator
- val unsafeRow = new UnsafeRow(numFieldsOfRight)
- new Iterator[UnsafeRow] {
- override def hasNext: Boolean = {
- iter.hasNext
- }
- override def next(): UnsafeRow = {
- iter.loadNext()
- unsafeRow.pointTo(iter.getBaseObject, iter.getBaseOffset, iter.getRecordLength)
- unsafeRow
- }
- }
- }
-
- val resultIter =
- for (x <- rdd1.iterator(partition.s1, context);
- y <- createIter()) yield (x, y)
- CompletionIterator[(UnsafeRow, UnsafeRow), Iterator[(UnsafeRow, UnsafeRow)]](
- resultIter, sorter.cleanupResources)
- }
-}
-
-
-case class CartesianProduct(left: SparkPlan, right: SparkPlan) extends BinaryNode {
- override def output: Seq[Attribute] = left.output ++ right.output
-
- override private[sql] lazy val metrics = Map(
- "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
-
- protected override def doExecute(): RDD[InternalRow] = {
- val numOutputRows = longMetric("numOutputRows")
-
- val leftResults = left.execute().asInstanceOf[RDD[UnsafeRow]]
- val rightResults = right.execute().asInstanceOf[RDD[UnsafeRow]]
-
- val pair = new UnsafeCartesianRDD(leftResults, rightResults, right.output.size)
- pair.mapPartitionsInternal { iter =>
- val joiner = GenerateUnsafeRowJoiner.create(left.schema, right.schema)
- iter.map { r =>
- numOutputRows += 1
- joiner.join(r._1, r._2)
- }
- }
- }
-}
http://git-wip-us.apache.org/repos/asf/spark/blob/d7d0cad0/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala
new file mode 100644
index 0000000..3ce7c0e
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala
@@ -0,0 +1,103 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.joins
+
+import org.apache.spark._
+import org.apache.spark.rdd.{CartesianPartition, CartesianRDD, RDD}
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeRow}
+import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeRowJoiner
+import org.apache.spark.sql.execution.{BinaryExecNode, SparkPlan}
+import org.apache.spark.sql.execution.metric.SQLMetrics
+import org.apache.spark.util.CompletionIterator
+import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter
+
+/**
+ * An optimized CartesianRDD for UnsafeRow, which caches the rows from the second child RDD.
+ * This is much faster than rebuilding the right partition for every row in the left RDD, and it
+ * also materializes the right RDD (in case the right RDD is nondeterministic).
+ */
+private[spark]
+class UnsafeCartesianRDD(left : RDD[UnsafeRow], right : RDD[UnsafeRow], numFieldsOfRight: Int)
+ extends CartesianRDD[UnsafeRow, UnsafeRow](left.sparkContext, left, right) {
+
+ override def compute(split: Partition, context: TaskContext): Iterator[(UnsafeRow, UnsafeRow)] = {
+ // We will not sort the rows, so prefixComparator and recordComparator are null.
+ val sorter = UnsafeExternalSorter.create(
+ context.taskMemoryManager(),
+ SparkEnv.get.blockManager,
+ SparkEnv.get.serializerManager,
+ context,
+ null,
+ null,
+ 1024,
+ SparkEnv.get.memoryManager.pageSizeBytes,
+ false)
+
+ val partition = split.asInstanceOf[CartesianPartition]
+ for (y <- rdd2.iterator(partition.s2, context)) {
+ sorter.insertRecord(y.getBaseObject, y.getBaseOffset, y.getSizeInBytes, 0)
+ }
+
+ // Create an iterator from the sorter and wrap it as Iterator[UnsafeRow]
+ def createIter(): Iterator[UnsafeRow] = {
+ val iter = sorter.getIterator
+ val unsafeRow = new UnsafeRow(numFieldsOfRight)
+ new Iterator[UnsafeRow] {
+ override def hasNext: Boolean = {
+ iter.hasNext
+ }
+ override def next(): UnsafeRow = {
+ iter.loadNext()
+ unsafeRow.pointTo(iter.getBaseObject, iter.getBaseOffset, iter.getRecordLength)
+ unsafeRow
+ }
+ }
+ }
+
+ val resultIter =
+ for (x <- rdd1.iterator(partition.s1, context);
+ y <- createIter()) yield (x, y)
+ CompletionIterator[(UnsafeRow, UnsafeRow), Iterator[(UnsafeRow, UnsafeRow)]](
+ resultIter, sorter.cleanupResources)
+ }
+}
+
+
+case class CartesianProductExec(left: SparkPlan, right: SparkPlan) extends BinaryExecNode {
+ override def output: Seq[Attribute] = left.output ++ right.output
+
+ override private[sql] lazy val metrics = Map(
+ "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
+
+ protected override def doExecute(): RDD[InternalRow] = {
+ val numOutputRows = longMetric("numOutputRows")
+
+ val leftResults = left.execute().asInstanceOf[RDD[UnsafeRow]]
+ val rightResults = right.execute().asInstanceOf[RDD[UnsafeRow]]
+
+ val pair = new UnsafeCartesianRDD(leftResults, rightResults, right.output.size)
+ pair.mapPartitionsInternal { iter =>
+ val joiner = GenerateUnsafeRowJoiner.create(left.schema, right.schema)
+ iter.map { r =>
+ numOutputRows += 1
+ joiner.join(r._1, r._2)
+ }
+ }
+ }
+}
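
Editor's note (not part of the patch): UnsafeCartesianRDD above improves on a plain CartesianRDD by inserting the right-hand partition into an UnsafeExternalSorter once (so it can spill to disk) and then replaying it for every left-hand row. A minimal sketch of that buffer-once, replay-many idea, with an in-memory Vector standing in for the spillable sorter (names and data are illustrative):

object CartesianSketch {
  def main(args: Array[String]): Unit = {
    def rightSide(): Iterator[String] = {
      println("materialising right side")    // in the patch this is rdd2.iterator(partition.s2, context)
      Iterator("x", "y", "z")
    }
    val leftRows = Iterator(1, 2, 3)

    // Materialise the right side exactly once; the patch inserts each row into an
    // UnsafeExternalSorter so the buffer can spill to disk instead of living on the heap.
    val cachedRight = rightSide().toVector

    // Replay the cached rows for every left row, like the for-comprehension in compute().
    val product = for (x <- leftRows; y <- cachedRight.iterator) yield (x, y)
    product.foreach(println)                 // (1,x), (1,y), (1,z), (2,x), ...
  }
}

In the operator itself, both children are consumed as RDD[UnsafeRow] and each pair is combined with a GenerateUnsafeRowJoiner, so the output stays in the unsafe row format without extra copies.
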