You are viewing a plain text version of this content. The canonical link for it is here.
Posted to reviews@spark.apache.org by cloud-fan <gi...@git.apache.org> on 2018/10/02 13:58:27 UTC
[GitHub] spark pull request #10989: [SPARK-12798] [SQL] generated BroadcastHashJoin
Github user cloud-fan commented on a diff in the pull request:
https://github.com/apache/spark/pull/10989#discussion_r221961271
--- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala ---
@@ -117,6 +120,87 @@ case class BroadcastHashJoin(
hashJoin(streamedIter, numStreamedRows, hashedRelation, numOutputRows)
}
}
+
+ // the term for hash relation
+ private var relationTerm: String = _
+
+ override def upstream(): RDD[InternalRow] = {
+ streamedPlan.asInstanceOf[CodegenSupport].upstream()
+ }
+
+ override def doProduce(ctx: CodegenContext): String = {
+ // create a name for HashRelation
+ val broadcastRelation = Await.result(broadcastFuture, timeout)
+ val broadcast = ctx.addReferenceObj("broadcast", broadcastRelation)
+ relationTerm = ctx.freshName("relation")
+ // TODO: create specialized HashRelation for single join key
+ val clsName = classOf[UnsafeHashedRelation].getName
+ ctx.addMutableState(clsName, relationTerm,
+ s"""
+ | $relationTerm = ($clsName) $broadcast.value();
+ | incPeakExecutionMemory($relationTerm.getUnsafeSize());
+ """.stripMargin)
+
+ s"""
+ | ${streamedPlan.asInstanceOf[CodegenSupport].produce(ctx, this)}
+ """.stripMargin
+ }
+
+ override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = {
+ // generate the key as UnsafeRow
+ ctx.currentVars = input
+ val keyExpr = streamedKeys.map(BindReferences.bindReference(_, streamedPlan.output))
+ val keyVal = GenerateUnsafeProjection.createCode(ctx, keyExpr)
+ val keyTerm = keyVal.value
+ val anyNull = if (keyExpr.exists(_.nullable)) s"$keyTerm.anyNull()" else "false"
+
+ // find the matches from HashedRelation
+ val matches = ctx.freshName("matches")
+ val bufferType = classOf[CompactBuffer[UnsafeRow]].getName
+ val i = ctx.freshName("i")
+ val size = ctx.freshName("size")
+ val row = ctx.freshName("row")
+
+ // create variables for output
+ ctx.currentVars = null
+ ctx.INPUT_ROW = row
+ val buildColumns = buildPlan.output.zipWithIndex.map { case (a, i) =>
+ BoundReference(i, a.dataType, a.nullable).gen(ctx)
+ }
+ val resultVars = buildSide match {
+ case BuildLeft => buildColumns ++ input
+ case BuildRight => input ++ buildColumns
+ }
+
+ val ouputCode = if (condition.isDefined) {
+ // filter the output via condition
+ ctx.currentVars = resultVars
+ val ev = BindReferences.bindReference(condition.get, this.output).gen(ctx)
+ s"""
+ | ${ev.code}
+ | if (!${ev.isNull} && ${ev.value}) {
+ | ${consume(ctx, resultVars)}
+ | }
+ """.stripMargin
+ } else {
+ consume(ctx, resultVars)
+ }
+
+ s"""
+ | // generate join key
+ | ${keyVal.code}
+ | // find matches from HashRelation
+ | $bufferType $matches = $anyNull ? null : ($bufferType) $relationTerm.get($keyTerm);
+ | if ($matches != null) {
+ | int $size = $matches.size();
+ | for (int $i = 0; $i < $size; $i++) {
--- End diff --
I don't see a strong reason why we can't interrupt this loop. For example, we could make `i` a global variable.
I'm not asking to change anything here; I just want to verify my understanding. Also cc @viirya @mgaido91
---
---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org