You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by GitBox <gi...@apache.org> on 2019/01/13 22:57:19 UTC
[spark] Diff for: [GitHub] asfgit closed pull request #23392:
[SPARK-26450][SQL] Avoid rebuilding map of schema for every column in
projection
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
index ea8c369ee49ed..7ae5924b20faf 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
@@ -86,4 +86,13 @@ object BindReferences extends Logging {
}
}.asInstanceOf[A] // Kind of a hack, but safe. TODO: Tighten return type when possible.
}
+
+ /**
+ * A helper function to bind given expressions to an input schema.
+ */
+ def bindReferences[A <: Expression](
+ expressions: Seq[A],
+ input: AttributeSeq): Seq[A] = {
+ expressions.map(BindReferences.bindReference(_, input))
+ }
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedMutableProjection.scala
index 122a564da61be..5c8aa4e2e9d83 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedMutableProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedMutableProjection.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp
@@ -30,7 +31,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp
*/
class InterpretedMutableProjection(expressions: Seq[Expression]) extends MutableProjection {
def this(expressions: Seq[Expression], inputSchema: Seq[Attribute]) =
- this(toBoundExprs(expressions, inputSchema))
+ this(bindReferences(expressions, inputSchema))
private[this] val buffer = new Array[Any](expressions.size)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
index b48f7ba655b2f..eaaf94baac216 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateMutableProjection, GenerateSafeProjection, GenerateUnsafeProjection}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{DataType, StructType}
@@ -30,7 +31,7 @@ import org.apache.spark.sql.types.{DataType, StructType}
*/
class InterpretedProjection(expressions: Seq[Expression]) extends Projection {
def this(expressions: Seq[Expression], inputSchema: Seq[Attribute]) =
- this(expressions.map(BindReferences.bindReference(_, inputSchema)))
+ this(bindReferences(expressions, inputSchema))
override def initialize(partitionIndex: Int): Unit = {
expressions.foreach(_.foreach {
@@ -99,7 +100,7 @@ object MutableProjection
* `inputSchema`.
*/
def create(exprs: Seq[Expression], inputSchema: Seq[Attribute]): MutableProjection = {
- create(toBoundExprs(exprs, inputSchema))
+ create(bindReferences(exprs, inputSchema))
}
}
@@ -162,7 +163,7 @@ object UnsafeProjection
* `inputSchema`.
*/
def create(exprs: Seq[Expression], inputSchema: Seq[Attribute]): UnsafeProjection = {
- create(toBoundExprs(exprs, inputSchema))
+ create(bindReferences(exprs, inputSchema))
}
}
@@ -203,6 +204,6 @@ object SafeProjection extends CodeGeneratorWithInterpretedFallback[Seq[Expressio
* `inputSchema`.
*/
def create(exprs: Seq[Expression], inputSchema: Seq[Attribute]): Projection = {
- create(toBoundExprs(exprs, inputSchema))
+ create(bindReferences(exprs, inputSchema))
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
index d588e7f081303..838bd1c679e4d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql.catalyst.expressions.codegen
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp
// MutableProjection is not accessible in Java
@@ -35,7 +36,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP
in.map(ExpressionCanonicalizer.execute)
protected def bind(in: Seq[Expression], inputSchema: Seq[Attribute]): Seq[Expression] =
- in.map(BindReferences.bindReference(_, inputSchema))
+ bindReferences(in, inputSchema)
def generate(
expressions: Seq[Expression],
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala
index 283fd2a6e9383..b66b80ad31dc2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala
@@ -25,6 +25,7 @@ import com.esotericsoftware.kryo.io.{Input, Output}
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.Utils
@@ -46,7 +47,7 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR
in.map(ExpressionCanonicalizer.execute(_).asInstanceOf[SortOrder])
protected def bind(in: Seq[SortOrder], inputSchema: Seq[Attribute]): Seq[SortOrder] =
- in.map(BindReferences.bindReference(_, inputSchema))
+ bindReferences(in, inputSchema)
/**
* Creates a code gen ordering for sorting this schema, in ascending order.
@@ -188,7 +189,7 @@ class LazilyGeneratedOrdering(val ordering: Seq[SortOrder])
extends Ordering[InternalRow] with KryoSerializable {
def this(ordering: Seq[SortOrder], inputSchema: Seq[Attribute]) =
- this(ordering.map(BindReferences.bindReference(_, inputSchema)))
+ this(bindReferences(ordering, inputSchema))
@transient
private[this] var generatedOrdering = GenerateOrdering.generate(ordering)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala
index 39778661d1c48..e285398ba1958 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala
@@ -21,6 +21,7 @@ import scala.annotation.tailrec
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData}
@@ -41,7 +42,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
in.map(ExpressionCanonicalizer.execute)
protected def bind(in: Seq[Expression], inputSchema: Seq[Attribute]): Seq[Expression] =
- in.map(BindReferences.bindReference(_, inputSchema))
+ bindReferences(in, inputSchema)
private def createCodeForStruct(
ctx: CodegenContext,
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
index 0ecd0de8d8203..fb1d8a3c8e739 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql.catalyst.expressions.codegen
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.types._
@@ -317,7 +318,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
in.map(ExpressionCanonicalizer.execute)
protected def bind(in: Seq[Expression], inputSchema: Seq[Attribute]): Seq[Expression] =
- in.map(BindReferences.bindReference(_, inputSchema))
+ bindReferences(in, inputSchema)
def generate(
expressions: Seq[Expression],
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ordering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ordering.scala
index e24a3de3cfdbe..c8d667143f452 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ordering.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ordering.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
import org.apache.spark.sql.types._
@@ -27,7 +28,7 @@ import org.apache.spark.sql.types._
class InterpretedOrdering(ordering: Seq[SortOrder]) extends Ordering[InternalRow] {
def this(ordering: Seq[SortOrder], inputSchema: Seq[Attribute]) =
- this(ordering.map(BindReferences.bindReference(_, inputSchema)))
+ this(bindReferences(ordering, inputSchema))
def compare(a: InternalRow, b: InternalRow): Int = {
var i = 0
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala
index bf18e8bcb52df..932c364737249 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala
@@ -85,13 +85,6 @@ package object expressions {
override def apply(row: InternalRow): InternalRow = row
}
- /**
- * A helper function to bind given expressions to an input schema.
- */
- def toBoundExprs(exprs: Seq[Expression], inputSchema: Seq[Attribute]): Seq[Expression] = {
- exprs.map(BindReferences.bindReference(_, inputSchema))
- }
-
/**
* Helper functions for working with `Seq[Attribute]`.
*/
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala
index 5b4edf5136e3f..85f49140a4b41 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala
@@ -145,11 +145,12 @@ case class ExpandExec(
// Part 1: declare variables for each column
// If a column has the same value for all output rows, then we also generate its computation
// right after declaration. Otherwise its value is computed in the part 2.
+ lazy val attributeSeq: AttributeSeq = child.output
val outputColumns = output.indices.map { col =>
val firstExpr = projections.head(col)
if (sameOutput(col)) {
// This column is the same across all output rows. Just generate code for it here.
- BindReferences.bindReference(firstExpr, child.output).genCode(ctx)
+ BindReferences.bindReference(firstExpr, attributeSeq).genCode(ctx)
} else {
val isNull = ctx.freshName("isNull")
val value = ctx.freshName("value")
@@ -170,7 +171,7 @@ case class ExpandExec(
var updateCode = ""
for (col <- exprs.indices) {
if (!sameOutput(col)) {
- val ev = BindReferences.bindReference(exprs(col), child.output).genCode(ctx)
+ val ev = BindReferences.bindReference(exprs(col), attributeSeq).genCode(ctx)
updateCode +=
s"""
|${ev.code}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala
index 98c4a51299958..a1fb23d621d49 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala
@@ -77,6 +77,7 @@ abstract class AggregationIterator(
val expressionsLength = expressions.length
val functions = new Array[AggregateFunction](expressionsLength)
var i = 0
+ val inputAttributeSeq: AttributeSeq = inputAttributes
while (i < expressionsLength) {
val func = expressions(i).aggregateFunction
val funcWithBoundReferences: AggregateFunction = expressions(i).mode match {
@@ -86,7 +87,7 @@ abstract class AggregationIterator(
// this function is Partial or Complete because we will call eval of this
// function's children in the update method of this aggregate function.
// Those eval calls require BoundReferences to work.
- BindReferences.bindReference(func, inputAttributes)
+ BindReferences.bindReference(func, inputAttributeSeq)
case _ =>
// We only need to set inputBufferOffset for aggregate functions with mode
// PartialMerge and Final.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
index 4827f838fc514..28801774418a1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
@@ -23,6 +23,7 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.errors._
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
@@ -199,15 +200,13 @@ case class HashAggregateExec(
val (resultVars, genResult) = if (modes.contains(Final) || modes.contains(Complete)) {
// evaluate aggregate results
ctx.currentVars = bufVars
- val aggResults = functions.map(_.evaluateExpression).map { e =>
- BindReferences.bindReference(e, aggregateBufferAttributes).genCode(ctx)
- }
+ val aggResults = bindReferences(
+ functions.map(_.evaluateExpression),
+ aggregateBufferAttributes).map(_.genCode(ctx))
val evaluateAggResults = evaluateVariables(aggResults)
// evaluate result expressions
ctx.currentVars = aggResults
- val resultVars = resultExpressions.map { e =>
- BindReferences.bindReference(e, aggregateAttributes).genCode(ctx)
- }
+ val resultVars = bindReferences(resultExpressions, aggregateAttributes).map(_.genCode(ctx))
(resultVars, s"""
|$evaluateAggResults
|${evaluateVariables(resultVars)}
@@ -264,7 +263,7 @@ case class HashAggregateExec(
}
}
ctx.currentVars = bufVars ++ input
- val boundUpdateExpr = updateExpr.map(BindReferences.bindReference(_, inputAttrs))
+ val boundUpdateExpr = bindReferences(updateExpr, inputAttrs)
val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr)
val effectiveCodes = subExprs.codes.mkString("\n")
val aggVals = ctx.withSubExprEliminationExprs(subExprs.states) {
@@ -456,16 +455,16 @@ case class HashAggregateExec(
val evaluateBufferVars = evaluateVariables(bufferVars)
// evaluate the aggregation result
ctx.currentVars = bufferVars
- val aggResults = declFunctions.map(_.evaluateExpression).map { e =>
- BindReferences.bindReference(e, aggregateBufferAttributes).genCode(ctx)
- }
+ val aggResults = bindReferences(
+ declFunctions.map(_.evaluateExpression),
+ aggregateBufferAttributes).map(_.genCode(ctx))
val evaluateAggResults = evaluateVariables(aggResults)
// generate the final result
ctx.currentVars = keyVars ++ aggResults
val inputAttrs = groupingAttributes ++ aggregateAttributes
- val resultVars = resultExpressions.map { e =>
- BindReferences.bindReference(e, inputAttrs).genCode(ctx)
- }
+ val resultVars = bindReferences[Expression](
+ resultExpressions,
+ inputAttrs).map(_.genCode(ctx))
s"""
$evaluateKeyVars
$evaluateBufferVars
@@ -494,9 +493,9 @@ case class HashAggregateExec(
ctx.currentVars = keyVars ++ resultBufferVars
val inputAttrs = resultExpressions.map(_.toAttribute)
- val resultVars = resultExpressions.map { e =>
- BindReferences.bindReference(e, inputAttrs).genCode(ctx)
- }
+ val resultVars = bindReferences[Expression](
+ resultExpressions,
+ inputAttrs).map(_.genCode(ctx))
s"""
$evaluateKeyVars
$evaluateResultBufferVars
@@ -506,9 +505,9 @@ case class HashAggregateExec(
// generate result based on grouping key
ctx.INPUT_ROW = keyTerm
ctx.currentVars = null
- val eval = resultExpressions.map{ e =>
- BindReferences.bindReference(e, groupingAttributes).genCode(ctx)
- }
+ val eval = bindReferences[Expression](
+ resultExpressions,
+ groupingAttributes).map(_.genCode(ctx))
consume(ctx, eval)
}
ctx.addNewFunction(funcName,
@@ -730,9 +729,9 @@ case class HashAggregateExec(
private def doConsumeWithKeys(ctx: CodegenContext, input: Seq[ExprCode]): String = {
// create grouping key
val unsafeRowKeyCode = GenerateUnsafeProjection.createCode(
- ctx, groupingExpressions.map(e => BindReferences.bindReference[Expression](e, child.output)))
+ ctx, bindReferences[Expression](groupingExpressions, child.output))
val fastRowKeys = ctx.generateExpressions(
- groupingExpressions.map(e => BindReferences.bindReference[Expression](e, child.output)))
+ bindReferences[Expression](groupingExpressions, child.output))
val unsafeRowKeys = unsafeRowKeyCode.value
val unsafeRowBuffer = ctx.freshName("unsafeRowAggBuffer")
val fastRowBuffer = ctx.freshName("fastAggBuffer")
@@ -825,7 +824,7 @@ case class HashAggregateExec(
val updateRowInRegularHashMap: String = {
ctx.INPUT_ROW = unsafeRowBuffer
- val boundUpdateExpr = updateExpr.map(BindReferences.bindReference(_, inputAttr))
+ val boundUpdateExpr = bindReferences(updateExpr, inputAttr)
val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr)
val effectiveCodes = subExprs.codes.mkString("\n")
val unsafeRowBufferEvals = ctx.withSubExprEliminationExprs(subExprs.states) {
@@ -849,7 +848,7 @@ case class HashAggregateExec(
if (isFastHashMapEnabled) {
if (isVectorizedHashMapEnabled) {
ctx.INPUT_ROW = fastRowBuffer
- val boundUpdateExpr = updateExpr.map(BindReferences.bindReference(_, inputAttr))
+ val boundUpdateExpr = bindReferences(updateExpr, inputAttr)
val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr)
val effectiveCodes = subExprs.codes.mkString("\n")
val fastRowEvals = ctx.withSubExprEliminationExprs(subExprs.states) {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
index 09effe087e195..7e87a150c0a4d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
@@ -24,6 +24,7 @@ import org.apache.spark.{InterruptibleIterator, Partition, SparkContext, TaskCon
import org.apache.spark.rdd.{EmptyRDD, PartitionwiseSampledRDD, RDD}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution.metric.SQLMetrics
@@ -56,7 +57,7 @@ case class ProjectExec(projectList: Seq[NamedExpression], child: SparkPlan)
}
override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
- val exprs = projectList.map(x => BindReferences.bindReference[Expression](x, child.output))
+ val exprs = bindReferences[Expression](projectList, child.output)
val resultVars = exprs.map(_.genCode(ctx))
// Evaluation of non-deterministic expressions can't be deferred.
val nonDeterministicAttrs = projectList.filterNot(_.deterministic).map(_.toAttribute)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
index 774fe38f5c2e6..260ad97506a85 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
@@ -34,6 +34,7 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.catalog.BucketSpec
import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils}
import org.apache.spark.sql.execution.{SortExec, SparkPlan, SQLExecution}
@@ -145,9 +146,8 @@ object FileFormatWriter extends Logging {
// SPARK-21165: the `requiredOrdering` is based on the attributes from analyzed plan, and
// the physical plan may have different attribute ids due to optimizer removing some
// aliases. Here we bind the expression ahead to avoid potential attribute ids mismatch.
- val orderingExpr = requiredOrdering
- .map(SortOrder(_, Ascending))
- .map(BindReferences.bindReference(_, outputSpec.outputColumns))
+ val orderingExpr = bindReferences(
+ requiredOrdering.map(SortOrder(_, Ascending)), outputSpec.outputColumns)
SortExec(
orderingExpr,
global = false,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
index 1aef5f6864263..5ee4c7ffb1911 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.joins
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.physical.Partitioning
import org.apache.spark.sql.execution.{RowIterator, SparkPlan}
@@ -63,9 +64,8 @@ trait HashJoin {
protected lazy val (buildKeys, streamedKeys) = {
require(leftKeys.map(_.dataType) == rightKeys.map(_.dataType),
"Join keys from two sides should have same types")
- val lkeys = HashJoin.rewriteKeyExpr(leftKeys).map(BindReferences.bindReference(_, left.output))
- val rkeys = HashJoin.rewriteKeyExpr(rightKeys)
- .map(BindReferences.bindReference(_, right.output))
+ val lkeys = bindReferences(HashJoin.rewriteKeyExpr(leftKeys), left.output)
+ val rkeys = bindReferences(HashJoin.rewriteKeyExpr(rightKeys), right.output)
buildSide match {
case BuildLeft => (lkeys, rkeys)
case BuildRight => (rkeys, lkeys)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
index d7d3f6d6078b4..f829f07e80720 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
@@ -22,6 +22,7 @@ import scala.collection.mutable.ArrayBuffer
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.plans._
@@ -393,7 +394,7 @@ case class SortMergeJoinExec(
input: Seq[Attribute]): Seq[ExprCode] = {
ctx.INPUT_ROW = row
ctx.currentVars = null
- keys.map(BindReferences.bindReference(_, input).genCode(ctx))
+ bindReferences(keys, input).map(_.genCode(ctx))
}
private def copyKeys(ctx: CodegenContext, vars: Seq[ExprCode]): Seq[ExprCode] = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowFunctionFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowFunctionFrame.scala
index 156002ef58fbe..5bf34558fe493 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowFunctionFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowFunctionFrame.scala
@@ -21,6 +21,7 @@ import java.util
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp
import org.apache.spark.sql.execution.ExternalAppendOnlyUnsafeRowArray
@@ -89,9 +90,8 @@ private[window] final class OffsetWindowFunctionFrame(
private[this] val projection = {
// Collect the expressions and bind them.
val inputAttrs = inputSchema.map(_.withNullability(true))
- val boundExpressions = Seq.fill(ordinal)(NoOp) ++ expressions.toSeq.map { e =>
- BindReferences.bindReference(e.input, inputAttrs)
- }
+ val boundExpressions = Seq.fill(ordinal)(NoOp) ++ bindReferences(
+ expressions.toSeq.map(_.input), inputAttrs)
// Create the projection.
newMutableProjection(boundExpressions, Nil).target(target)
@@ -100,7 +100,7 @@ private[window] final class OffsetWindowFunctionFrame(
/** Create the projection used when the offset row DOES NOT exist. */
private[this] val fillDefaultValue = {
// Collect the expressions and bind them.
- val inputAttrs = inputSchema.map(_.withNullability(true))
+ val inputAttrs: AttributeSeq = inputSchema.map(_.withNullability(true))
val boundExpressions = Seq.fill(ordinal)(NoOp) ++ expressions.toSeq.map { e =>
if (e.default == null || e.default.foldable && e.default.eval() == null) {
// The default value is null.
With regards,
Apache Git Services
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org