Posted to commits@spark.apache.org by GitBox <gi...@apache.org> on 2019/01/13 22:57:19 UTC

[spark] Diff for: [GitHub] asfgit closed pull request #23392: [SPARK-26450][SQL] Avoid rebuilding map of schema for every column in projection
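In short: BindReferences.bindReference takes an AttributeSeq, so binding a list
of expressions with `exprs.map(BindReferences.bindReference(_, inputSchema))`
triggers the implicit Seq[Attribute] => AttributeSeq conversion on every
iteration, rebuilding the schema lookup map once per column. The patch adds a
bindReferences helper that converts once per list, and hoists explicit
AttributeSeq values where single bindings sit inside loops. Schematically
(`exprs` and `schema` are placeholder names, not identifiers from the patch):

    // Before: N conversions of `schema` for N expressions.
    val boundOld = exprs.map(BindReferences.bindReference(_, schema))
    // After: one conversion, shared by the whole list.
    val boundNew = BindReferences.bindReferences(exprs, schema)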

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
index ea8c369ee49ed..7ae5924b20faf 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
@@ -86,4 +86,13 @@ object BindReferences extends Logging {
       }
     }.asInstanceOf[A] // Kind of a hack, but safe.  TODO: Tighten return type when possible.
   }
+
+  /**
+   * A helper function to bind given expressions to an input schema.
+   */
+  def bindReferences[A <: Expression](
+      expressions: Seq[A],
+      input: AttributeSeq): Seq[A] = {
+    expressions.map(BindReferences.bindReference(_, input))
+  }
 }
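
The key detail of the new helper is its parameter type: `input` is declared as
AttributeSeq rather than Seq[Attribute], so the implicit wrap (AttributeSeq is
the implicit wrapper class defined in the catalyst `expressions` package
object, the same file where toBoundExprs is removed below) is applied exactly
once at the call site. A minimal runnable sketch, with hypothetical attribute
names:

    import org.apache.spark.sql.catalyst.expressions._
    import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
    import org.apache.spark.sql.types.{IntegerType, StringType}

    val a = AttributeReference("a", IntegerType)()
    val b = AttributeReference("b", StringType)()
    val schema: Seq[Attribute] = Seq(a, b)

    // `schema` is wrapped in an AttributeSeq once, here; both bindings then
    // reuse the same instance and its cached lookup structures.
    val bound: Seq[Expression] = bindReferences[Expression](Seq(a, b), schema)
    // Old pattern: Seq(a, b).map(BindReferences.bindReference(_, schema))
    // wrapped `schema` in a fresh AttributeSeq on every iteration.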
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedMutableProjection.scala
index 122a564da61be..5c8aa4e2e9d83 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedMutableProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InterpretedMutableProjection.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.catalyst.expressions
 
 import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
 import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp
 
 
@@ -30,7 +31,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp
  */
 class InterpretedMutableProjection(expressions: Seq[Expression]) extends MutableProjection {
   def this(expressions: Seq[Expression], inputSchema: Seq[Attribute]) =
-    this(toBoundExprs(expressions, inputSchema))
+    this(bindReferences(expressions, inputSchema))
 
   private[this] val buffer = new Array[Any](expressions.size)
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
index b48f7ba655b2f..eaaf94baac216 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.catalyst.expressions
 
 import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
 import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateMutableProjection, GenerateSafeProjection, GenerateUnsafeProjection}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.{DataType, StructType}
@@ -30,7 +31,7 @@ import org.apache.spark.sql.types.{DataType, StructType}
  */
 class InterpretedProjection(expressions: Seq[Expression]) extends Projection {
   def this(expressions: Seq[Expression], inputSchema: Seq[Attribute]) =
-    this(expressions.map(BindReferences.bindReference(_, inputSchema)))
+    this(bindReferences(expressions, inputSchema))
 
   override def initialize(partitionIndex: Int): Unit = {
     expressions.foreach(_.foreach {
@@ -99,7 +100,7 @@ object MutableProjection
    * `inputSchema`.
    */
   def create(exprs: Seq[Expression], inputSchema: Seq[Attribute]): MutableProjection = {
-    create(toBoundExprs(exprs, inputSchema))
+    create(bindReferences(exprs, inputSchema))
   }
 }
 
@@ -162,7 +163,7 @@ object UnsafeProjection
    * `inputSchema`.
    */
   def create(exprs: Seq[Expression], inputSchema: Seq[Attribute]): UnsafeProjection = {
-    create(toBoundExprs(exprs, inputSchema))
+    create(bindReferences(exprs, inputSchema))
   }
 }
 
@@ -203,6 +204,6 @@ object SafeProjection extends CodeGeneratorWithInterpretedFallback[Seq[Expressio
    * `inputSchema`.
    */
   def create(exprs: Seq[Expression], inputSchema: Seq[Attribute]): Projection = {
-    create(toBoundExprs(exprs, inputSchema))
+    create(bindReferences(exprs, inputSchema))
   }
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
index d588e7f081303..838bd1c679e4d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.catalyst.expressions.codegen
 
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
 import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp
 
 // MutableProjection is not accessible in Java
@@ -35,7 +36,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP
     in.map(ExpressionCanonicalizer.execute)
 
   protected def bind(in: Seq[Expression], inputSchema: Seq[Attribute]): Seq[Expression] =
-    in.map(BindReferences.bindReference(_, inputSchema))
+    bindReferences(in, inputSchema)
 
   def generate(
       expressions: Seq[Expression],
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala
index 283fd2a6e9383..b66b80ad31dc2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala
@@ -25,6 +25,7 @@ import com.esotericsoftware.kryo.io.{Input, Output}
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.util.Utils
 
@@ -46,7 +47,7 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR
     in.map(ExpressionCanonicalizer.execute(_).asInstanceOf[SortOrder])
 
   protected def bind(in: Seq[SortOrder], inputSchema: Seq[Attribute]): Seq[SortOrder] =
-    in.map(BindReferences.bindReference(_, inputSchema))
+    bindReferences(in, inputSchema)
 
   /**
    * Creates a code gen ordering for sorting this schema, in ascending order.
@@ -188,7 +189,7 @@ class LazilyGeneratedOrdering(val ordering: Seq[SortOrder])
   extends Ordering[InternalRow] with KryoSerializable {
 
   def this(ordering: Seq[SortOrder], inputSchema: Seq[Attribute]) =
-    this(ordering.map(BindReferences.bindReference(_, inputSchema)))
+    this(bindReferences(ordering, inputSchema))
 
   @transient
   private[this] var generatedOrdering = GenerateOrdering.generate(ordering)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala
index 39778661d1c48..e285398ba1958 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala
@@ -21,6 +21,7 @@ import scala.annotation.tailrec
 
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
 import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
 import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData}
@@ -41,7 +42,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
     in.map(ExpressionCanonicalizer.execute)
 
   protected def bind(in: Seq[Expression], inputSchema: Seq[Attribute]): Seq[Expression] =
-    in.map(BindReferences.bindReference(_, inputSchema))
+    bindReferences(in, inputSchema)
 
   private def createCodeForStruct(
       ctx: CodegenContext,
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
index 0ecd0de8d8203..fb1d8a3c8e739 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.catalyst.expressions.codegen
 
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
 import org.apache.spark.sql.types._
 
@@ -317,7 +318,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
     in.map(ExpressionCanonicalizer.execute)
 
   protected def bind(in: Seq[Expression], inputSchema: Seq[Attribute]): Seq[Expression] =
-    in.map(BindReferences.bindReference(_, inputSchema))
+    bindReferences(in, inputSchema)
 
   def generate(
       expressions: Seq[Expression],
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ordering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ordering.scala
index e24a3de3cfdbe..c8d667143f452 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ordering.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ordering.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.catalyst.expressions
 
 import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
 import org.apache.spark.sql.types._
 
 
@@ -27,7 +28,7 @@ import org.apache.spark.sql.types._
 class InterpretedOrdering(ordering: Seq[SortOrder]) extends Ordering[InternalRow] {
 
   def this(ordering: Seq[SortOrder], inputSchema: Seq[Attribute]) =
-    this(ordering.map(BindReferences.bindReference(_, inputSchema)))
+    this(bindReferences(ordering, inputSchema))
 
   def compare(a: InternalRow, b: InternalRow): Int = {
     var i = 0
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala
index bf18e8bcb52df..932c364737249 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala
@@ -85,13 +85,6 @@ package object expressions  {
     override def apply(row: InternalRow): InternalRow = row
   }
 
-  /**
-   * A helper function to bind given expressions to an input schema.
-   */
-  def toBoundExprs(exprs: Seq[Expression], inputSchema: Seq[Attribute]): Seq[Expression] = {
-    exprs.map(BindReferences.bindReference(_, inputSchema))
-  }
-
   /**
    * Helper functions for working with `Seq[Attribute]`.
    */
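
For reference, the removed and added helper signatures side by side, as they
appear in this diff. Besides taking AttributeSeq directly, the new one is
generic in the element type, which is why the GenerateOrdering and
LazilyGeneratedOrdering hunks above can pass a Seq[SortOrder] and get a
Seq[SortOrder] back:

    // Removed from the expressions package object:
    def toBoundExprs(exprs: Seq[Expression], inputSchema: Seq[Attribute]): Seq[Expression]

    // Added to BindReferences (first hunk of this diff):
    def bindReferences[A <: Expression](expressions: Seq[A], input: AttributeSeq): Seq[A]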
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala
index 5b4edf5136e3f..85f49140a4b41 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala
@@ -145,11 +145,12 @@ case class ExpandExec(
     // Part 1: declare variables for each column
     // If a column has the same value for all output rows, then we also generate its computation
     // right after declaration. Otherwise its value is computed in the part 2.
+    lazy val attributeSeq: AttributeSeq = child.output
     val outputColumns = output.indices.map { col =>
       val firstExpr = projections.head(col)
       if (sameOutput(col)) {
         // This column is the same across all output rows. Just generate code for it here.
-        BindReferences.bindReference(firstExpr, child.output).genCode(ctx)
+        BindReferences.bindReference(firstExpr, attributeSeq).genCode(ctx)
       } else {
         val isNull = ctx.freshName("isNull")
         val value = ctx.freshName("value")
@@ -170,7 +171,7 @@ case class ExpandExec(
       var updateCode = ""
       for (col <- exprs.indices) {
         if (!sameOutput(col)) {
-          val ev = BindReferences.bindReference(exprs(col), child.output).genCode(ctx)
+          val ev = BindReferences.bindReference(exprs(col), attributeSeq).genCode(ctx)
           updateCode +=
             s"""
                |${ev.code}
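
Where individual bindReference calls remain inside a loop, the fix is instead
to pin the conversion: annotating a local val as AttributeSeq forces the
implicit wrap to run once, and each iteration then reuses the cached lookup
map. The same hoist appears in AggregationIterator and
OffsetWindowFunctionFrame further down. Condensed from the ExpandExec hunk
above (the sameOutput branch is omitted):

    lazy val attributeSeq: AttributeSeq = child.output  // converted once, lazily
    val outputColumns = output.indices.map { col =>
      BindReferences.bindReference(projections.head(col), attributeSeq).genCode(ctx)
    }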
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala
index 98c4a51299958..a1fb23d621d49 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala
@@ -77,6 +77,7 @@ abstract class AggregationIterator(
     val expressionsLength = expressions.length
     val functions = new Array[AggregateFunction](expressionsLength)
     var i = 0
+    val inputAttributeSeq: AttributeSeq = inputAttributes
     while (i < expressionsLength) {
       val func = expressions(i).aggregateFunction
       val funcWithBoundReferences: AggregateFunction = expressions(i).mode match {
@@ -86,7 +87,7 @@ abstract class AggregationIterator(
           // this function is Partial or Complete because we will call eval of this
           // function's children in the update method of this aggregate function.
           // Those eval calls require BoundReferences to work.
-          BindReferences.bindReference(func, inputAttributes)
+          BindReferences.bindReference(func, inputAttributeSeq)
         case _ =>
           // We only need to set inputBufferOffset for aggregate functions with mode
           // PartialMerge and Final.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
index 4827f838fc514..28801774418a1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
@@ -23,6 +23,7 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.errors._
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
 import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
@@ -199,15 +200,13 @@ case class HashAggregateExec(
     val (resultVars, genResult) = if (modes.contains(Final) || modes.contains(Complete)) {
       // evaluate aggregate results
       ctx.currentVars = bufVars
-      val aggResults = functions.map(_.evaluateExpression).map { e =>
-        BindReferences.bindReference(e, aggregateBufferAttributes).genCode(ctx)
-      }
+      val aggResults = bindReferences(
+        functions.map(_.evaluateExpression),
+        aggregateBufferAttributes).map(_.genCode(ctx))
       val evaluateAggResults = evaluateVariables(aggResults)
       // evaluate result expressions
       ctx.currentVars = aggResults
-      val resultVars = resultExpressions.map { e =>
-        BindReferences.bindReference(e, aggregateAttributes).genCode(ctx)
-      }
+      val resultVars = bindReferences(resultExpressions, aggregateAttributes).map(_.genCode(ctx))
       (resultVars, s"""
         |$evaluateAggResults
         |${evaluateVariables(resultVars)}
@@ -264,7 +263,7 @@ case class HashAggregateExec(
       }
     }
     ctx.currentVars = bufVars ++ input
-    val boundUpdateExpr = updateExpr.map(BindReferences.bindReference(_, inputAttrs))
+    val boundUpdateExpr = bindReferences(updateExpr, inputAttrs)
     val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr)
     val effectiveCodes = subExprs.codes.mkString("\n")
     val aggVals = ctx.withSubExprEliminationExprs(subExprs.states) {
@@ -456,16 +455,16 @@ case class HashAggregateExec(
       val evaluateBufferVars = evaluateVariables(bufferVars)
       // evaluate the aggregation result
       ctx.currentVars = bufferVars
-      val aggResults = declFunctions.map(_.evaluateExpression).map { e =>
-        BindReferences.bindReference(e, aggregateBufferAttributes).genCode(ctx)
-      }
+      val aggResults = bindReferences(
+        declFunctions.map(_.evaluateExpression),
+        aggregateBufferAttributes).map(_.genCode(ctx))
       val evaluateAggResults = evaluateVariables(aggResults)
       // generate the final result
       ctx.currentVars = keyVars ++ aggResults
       val inputAttrs = groupingAttributes ++ aggregateAttributes
-      val resultVars = resultExpressions.map { e =>
-        BindReferences.bindReference(e, inputAttrs).genCode(ctx)
-      }
+      val resultVars = bindReferences[Expression](
+        resultExpressions,
+        inputAttrs).map(_.genCode(ctx))
       s"""
        $evaluateKeyVars
        $evaluateBufferVars
@@ -494,9 +493,9 @@ case class HashAggregateExec(
 
       ctx.currentVars = keyVars ++ resultBufferVars
       val inputAttrs = resultExpressions.map(_.toAttribute)
-      val resultVars = resultExpressions.map { e =>
-        BindReferences.bindReference(e, inputAttrs).genCode(ctx)
-      }
+      val resultVars = bindReferences[Expression](
+        resultExpressions,
+        inputAttrs).map(_.genCode(ctx))
       s"""
        $evaluateKeyVars
        $evaluateResultBufferVars
@@ -506,9 +505,9 @@ case class HashAggregateExec(
       // generate result based on grouping key
       ctx.INPUT_ROW = keyTerm
       ctx.currentVars = null
-      val eval = resultExpressions.map{ e =>
-        BindReferences.bindReference(e, groupingAttributes).genCode(ctx)
-      }
+      val eval = bindReferences[Expression](
+        resultExpressions,
+        groupingAttributes).map(_.genCode(ctx))
       consume(ctx, eval)
     }
     ctx.addNewFunction(funcName,
@@ -730,9 +729,9 @@ case class HashAggregateExec(
   private def doConsumeWithKeys(ctx: CodegenContext, input: Seq[ExprCode]): String = {
     // create grouping key
     val unsafeRowKeyCode = GenerateUnsafeProjection.createCode(
-      ctx, groupingExpressions.map(e => BindReferences.bindReference[Expression](e, child.output)))
+      ctx, bindReferences[Expression](groupingExpressions, child.output))
     val fastRowKeys = ctx.generateExpressions(
-      groupingExpressions.map(e => BindReferences.bindReference[Expression](e, child.output)))
+      bindReferences[Expression](groupingExpressions, child.output))
     val unsafeRowKeys = unsafeRowKeyCode.value
     val unsafeRowBuffer = ctx.freshName("unsafeRowAggBuffer")
     val fastRowBuffer = ctx.freshName("fastAggBuffer")
@@ -825,7 +824,7 @@ case class HashAggregateExec(
 
     val updateRowInRegularHashMap: String = {
       ctx.INPUT_ROW = unsafeRowBuffer
-      val boundUpdateExpr = updateExpr.map(BindReferences.bindReference(_, inputAttr))
+      val boundUpdateExpr = bindReferences(updateExpr, inputAttr)
       val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr)
       val effectiveCodes = subExprs.codes.mkString("\n")
       val unsafeRowBufferEvals = ctx.withSubExprEliminationExprs(subExprs.states) {
@@ -849,7 +848,7 @@ case class HashAggregateExec(
       if (isFastHashMapEnabled) {
         if (isVectorizedHashMapEnabled) {
           ctx.INPUT_ROW = fastRowBuffer
-          val boundUpdateExpr = updateExpr.map(BindReferences.bindReference(_, inputAttr))
+          val boundUpdateExpr = bindReferences(updateExpr, inputAttr)
           val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr)
           val effectiveCodes = subExprs.codes.mkString("\n")
           val fastRowEvals = ctx.withSubExprEliminationExprs(subExprs.states) {
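
One subtlety in the HashAggregateExec hunks: the type argument is pinned, as in
bindReferences[Expression](resultExpressions, inputAttrs). resultExpressions is
a Seq[NamedExpression], and binding can rewrite attribute nodes into
BoundReference, which is not a NamedExpression; letting A infer as
NamedExpression would only be masked by the asInstanceOf flagged as a hack in
bindReference itself. A sketch of the call shape (the rationale is inferred,
not stated in the patch):

    // Pinning A = Expression keeps the static result type honest, since the
    // bound result may contain BoundReference nodes.
    val resultVars = bindReferences[Expression](resultExpressions, inputAttrs)
      .map(_.genCode(ctx))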
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
index 09effe087e195..7e87a150c0a4d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
@@ -24,6 +24,7 @@ import org.apache.spark.{InterruptibleIterator, Partition, SparkContext, TaskCon
 import org.apache.spark.rdd.{EmptyRDD, PartitionwiseSampledRDD, RDD}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.execution.metric.SQLMetrics
@@ -56,7 +57,7 @@ case class ProjectExec(projectList: Seq[NamedExpression], child: SparkPlan)
   }
 
   override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
-    val exprs = projectList.map(x => BindReferences.bindReference[Expression](x, child.output))
+    val exprs = bindReferences[Expression](projectList, child.output)
     val resultVars = exprs.map(_.genCode(ctx))
     // Evaluation of non-deterministic expressions can't be deferred.
     val nonDeterministicAttrs = projectList.filterNot(_.deterministic).map(_.toAttribute)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
index 774fe38f5c2e6..260ad97506a85 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
@@ -34,6 +34,7 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.catalog.BucketSpec
 import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
 import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
 import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils}
 import org.apache.spark.sql.execution.{SortExec, SparkPlan, SQLExecution}
@@ -145,9 +146,8 @@ object FileFormatWriter extends Logging {
         // SPARK-21165: the `requiredOrdering` is based on the attributes from analyzed plan, and
         // the physical plan may have different attribute ids due to optimizer removing some
         // aliases. Here we bind the expression ahead to avoid potential attribute ids mismatch.
-        val orderingExpr = requiredOrdering
-          .map(SortOrder(_, Ascending))
-          .map(BindReferences.bindReference(_, outputSpec.outputColumns))
+        val orderingExpr = bindReferences(
+          requiredOrdering.map(SortOrder(_, Ascending)), outputSpec.outputColumns)
         SortExec(
           orderingExpr,
           global = false,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
index 1aef5f6864263..5ee4c7ffb1911 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.joins
 
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.physical.Partitioning
 import org.apache.spark.sql.execution.{RowIterator, SparkPlan}
@@ -63,9 +64,8 @@ trait HashJoin {
   protected lazy val (buildKeys, streamedKeys) = {
     require(leftKeys.map(_.dataType) == rightKeys.map(_.dataType),
       "Join keys from two sides should have same types")
-    val lkeys = HashJoin.rewriteKeyExpr(leftKeys).map(BindReferences.bindReference(_, left.output))
-    val rkeys = HashJoin.rewriteKeyExpr(rightKeys)
-      .map(BindReferences.bindReference(_, right.output))
+    val lkeys = bindReferences(HashJoin.rewriteKeyExpr(leftKeys), left.output)
+    val rkeys = bindReferences(HashJoin.rewriteKeyExpr(rightKeys), right.output)
     buildSide match {
       case BuildLeft => (lkeys, rkeys)
       case BuildRight => (rkeys, lkeys)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
index d7d3f6d6078b4..f829f07e80720 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
@@ -22,6 +22,7 @@ import scala.collection.mutable.ArrayBuffer
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
 import org.apache.spark.sql.catalyst.plans._
@@ -393,7 +394,7 @@ case class SortMergeJoinExec(
       input: Seq[Attribute]): Seq[ExprCode] = {
     ctx.INPUT_ROW = row
     ctx.currentVars = null
-    keys.map(BindReferences.bindReference(_, input).genCode(ctx))
+    bindReferences(keys, input).map(_.genCode(ctx))
   }
 
   private def copyKeys(ctx: CodegenContext, vars: Seq[ExprCode]): Seq[ExprCode] = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowFunctionFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowFunctionFrame.scala
index 156002ef58fbe..5bf34558fe493 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowFunctionFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowFunctionFrame.scala
@@ -21,6 +21,7 @@ import java.util
 
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
 import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp
 import org.apache.spark.sql.execution.ExternalAppendOnlyUnsafeRowArray
 
@@ -89,9 +90,8 @@ private[window] final class OffsetWindowFunctionFrame(
   private[this] val projection = {
     // Collect the expressions and bind them.
     val inputAttrs = inputSchema.map(_.withNullability(true))
-    val boundExpressions = Seq.fill(ordinal)(NoOp) ++ expressions.toSeq.map { e =>
-      BindReferences.bindReference(e.input, inputAttrs)
-    }
+    val boundExpressions = Seq.fill(ordinal)(NoOp) ++ bindReferences(
+      expressions.toSeq.map(_.input), inputAttrs)
 
     // Create the projection.
     newMutableProjection(boundExpressions, Nil).target(target)
@@ -100,7 +100,7 @@ private[window] final class OffsetWindowFunctionFrame(
   /** Create the projection used when the offset row DOES NOT exists. */
   private[this] val fillDefaultValue = {
     // Collect the expressions and bind them.
-    val inputAttrs = inputSchema.map(_.withNullability(true))
+    val inputAttrs: AttributeSeq = inputSchema.map(_.withNullability(true))
     val boundExpressions = Seq.fill(ordinal)(NoOp) ++ expressions.toSeq.map { e =>
       if (e.default == null || e.default.foldable && e.default.eval() == null) {
         // The default value is null.


With regards,
Apache Git Services
