Posted to commits@spark.apache.org by rx...@apache.org on 2016/11/02 18:41:53 UTC

spark git commit: [SPARK-14393][SQL] values generated by non-deterministic functions shouldn't change after coalesce or union

Repository: spark
Updated Branches:
  refs/heads/master 742e0fea5 -> 02f203107


[SPARK-14393][SQL] values generated by non-deterministic functions shouldn't change after coalesce or union

## What changes were proposed in this pull request?

When a user appends a column to a DataFrame using a "nondeterministic" function, e.g., `rand`, `randn`, or `monotonically_increasing_id`, the expected semantics are the following:
- The value in each row should remain unchanged, as if we materialized the column immediately, regardless of later DataFrame operations.

However, since we use `TaskContext.getPartitionId` to get the partition index from the current thread, the values in nondeterministic columns might change if we call `union` or `coalesce` afterwards. `TaskContext.getPartitionId` returns the partition index of the current Spark task, which might not correspond to the partition index of the DataFrame where the column was defined.

See the unit tests below or JIRA for examples.
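For instance, a sketch of the failure mode (values and session setup are illustrative, e.g. a `spark-shell` session):

```scala
import org.apache.spark.sql.functions.rand

val df = spark.range(0, 4, 1, 4).withColumn("r", rand(42))
val before = df.collect()
val afterCoalesce = df.coalesce(2).collect()
// Expected: the same rows, as if column "r" had been materialized immediately.
// Before this patch: the "r" values could differ, because the RNG was re-seeded
// with TaskContext.getPartitionId() of the coalesced task rather than with the
// original partition index.
```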

This PR uses the partition index from `RDD.mapPartitionsWithIndex` instead of `TaskContext` and fixes the partition initialization logic in whole-stage codegen, normal codegen, and codegen fallback. An `initialize(partitionIndex: Int)` method was added to `Projection`, `Nondeterministic`, and `Predicate` (codegen) and is called right after object creation in `mapPartitionsWithIndex`. `newPredicate` now returns a `Predicate` instance rather than a function, so that it can be initialized properly.
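As a concrete illustration, the pattern adopted throughout this patch (simplified from the `ProjectExec` change below):

```scala
child.execute().mapPartitionsWithIndexInternal { (index, iter) =>
  val project = UnsafeProjection.create(projectList, child.output)
  project.initialize(index)  // seed nondeterministic state with the real partition index
  iter.map(project)
}
```
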
## How was this patch tested?

Unit tests. (Actually I'm not very confident that this PR fixed all issues without introducing new ones ...)

cc: rxin davies

Author: Xiangrui Meng <me...@databricks.com>

Closes #15567 from mengxr/SPARK-14393.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/02f20310
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/02f20310
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/02f20310

Branch: refs/heads/master
Commit: 02f203107b8eda1f1576e36c4f12b0e3bc5e910e
Parents: 742e0fe
Author: Xiangrui Meng <me...@databricks.com>
Authored: Wed Nov 2 11:41:49 2016 -0700
Committer: Reynold Xin <rx...@databricks.com>
Committed: Wed Nov 2 11:41:49 2016 -0700

----------------------------------------------------------------------
 .../main/scala/org/apache/spark/rdd/RDD.scala   | 16 +++++-
 .../sql/catalyst/expressions/Expression.scala   | 19 +++++--
 .../catalyst/expressions/InputFileName.scala    |  2 +-
 .../expressions/MonotonicallyIncreasingID.scala | 11 +++--
 .../sql/catalyst/expressions/Projection.scala   | 22 ++++++---
 .../catalyst/expressions/SparkPartitionID.scala | 13 +++--
 .../expressions/codegen/CodeGenerator.scala     | 14 ++++++
 .../expressions/codegen/CodegenFallback.scala   | 18 +++++--
 .../codegen/GenerateMutableProjection.scala     |  4 ++
 .../expressions/codegen/GeneratePredicate.scala | 18 +++++--
 .../codegen/GenerateSafeProjection.scala        |  4 ++
 .../codegen/GenerateUnsafeProjection.scala      |  4 ++
 .../sql/catalyst/expressions/package.scala      | 10 +++-
 .../sql/catalyst/expressions/predicates.scala   |  4 --
 .../expressions/randomExpressions.scala         | 14 +++---
 .../sql/catalyst/optimizer/Optimizer.scala      |  1 +
 .../expressions/ExpressionEvalHelper.scala      |  5 +-
 .../codegen/CodegenExpressionCachingSuite.scala | 13 +++--
 .../sql/execution/DataSourceScanExec.scala      |  6 ++-
 .../spark/sql/execution/ExistingRDD.scala       |  3 +-
 .../spark/sql/execution/GenerateExec.scala      |  3 +-
 .../apache/spark/sql/execution/SparkPlan.scala  |  4 +-
 .../sql/execution/WholeStageCodegenExec.scala   |  8 ++-
 .../sql/execution/basicPhysicalOperators.scala  |  8 +--
 .../columnar/InMemoryTableScanExec.scala        |  5 +-
 .../joins/BroadcastNestedLoopJoinExec.scala     |  7 +--
 .../execution/joins/CartesianProductExec.scala  |  8 +--
 .../spark/sql/execution/joins/HashJoin.scala    |  2 +-
 .../sql/execution/joins/SortMergeJoinExec.scala |  2 +-
 .../apache/spark/sql/execution/objects.scala    |  6 ++-
 .../spark/sql/DataFrameFunctionsSuite.scala     | 52 ++++++++++++++++++++
 .../sql/hive/execution/HiveTableScanExec.scala  |  3 +-
 32 files changed, 231 insertions(+), 78 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/02f20310/core/src/main/scala/org/apache/spark/rdd/RDD.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index db535de..e018af3 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -788,14 +788,26 @@ abstract class RDD[T: ClassTag](
   }
 
   /**
-   * [performance] Spark's internal mapPartitions method which skips closure cleaning. It is a
-   * performance API to be used carefully only if we are sure that the RDD elements are
+   * [performance] Spark's internal mapPartitionsWithIndex method that skips closure cleaning.
+   * It is a performance API to be used carefully only if we are sure that the RDD elements are
    * serializable and don't require closure cleaning.
    *
    * @param preservesPartitioning indicates whether the input function preserves the partitioner,
    * which should be `false` unless this is a pair RDD and the input function doesn't modify
    * the keys.
    */
+  private[spark] def mapPartitionsWithIndexInternal[U: ClassTag](
+      f: (Int, Iterator[T]) => Iterator[U],
+      preservesPartitioning: Boolean = false): RDD[U] = withScope {
+    new MapPartitionsRDD(
+      this,
+      (context: TaskContext, index: Int, iter: Iterator[T]) => f(index, iter),
+      preservesPartitioning)
+  }
+
+  /**
+   * [performance] Spark's internal mapPartitions method that skips closure cleaning.
+   */
   private[spark] def mapPartitionsInternal[U: ClassTag](
       f: Iterator[T] => Iterator[U],
       preservesPartitioning: Boolean = false): RDD[U] = withScope {

http://git-wip-us.apache.org/repos/asf/spark/blob/02f20310/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index 9edc1ce..726a231 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -272,17 +272,28 @@ trait Nondeterministic extends Expression {
   final override def deterministic: Boolean = false
   final override def foldable: Boolean = false
 
+  @transient
   private[this] var initialized = false
 
-  final def setInitialValues(): Unit = {
-    initInternal()
+  /**
+   * Initializes internal states given the current partition index and mark this as initialized.
+   * Subclasses should override [[initializeInternal()]].
+   */
+  final def initialize(partitionIndex: Int): Unit = {
+    initializeInternal(partitionIndex)
     initialized = true
   }
 
-  protected def initInternal(): Unit
+  protected def initializeInternal(partitionIndex: Int): Unit
 
+  /**
+   * @inheritdoc
+   * Throws an exception if [[initialize()]] is not called yet.
+   * Subclasses should override [[evalInternal()]].
+   */
   final override def eval(input: InternalRow = null): Any = {
-    require(initialized, "nondeterministic expression should be initialized before evaluate")
+    require(initialized,
+      s"Nondeterministic expression ${this.getClass.getName} should be initialized before eval.")
     evalInternal(input)
   }
 

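To make the new contract concrete, a self-contained analogue (not part of the patch) of how an implementation behaves under `initialize`/`eval`:

```scala
import scala.util.Random

// Stand-in for a Nondeterministic expression: state derives from the partition
// index passed to initialize, never from TaskContext, and eval fails fast if
// initialize has not been called.
class PartitionSeededRand(seed: Long) {
  @transient private var initialized = false
  @transient private var rng: Random = _

  def initialize(partitionIndex: Int): Unit = {
    rng = new Random(seed + partitionIndex)
    initialized = true
  }

  def eval(): Double = {
    require(initialized, "should be initialized before eval")
    rng.nextDouble()
  }
}
```
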
http://git-wip-us.apache.org/repos/asf/spark/blob/02f20310/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala
index 96929ec..b6c12c5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala
@@ -37,7 +37,7 @@ case class InputFileName() extends LeafExpression with Nondeterministic {
 
   override def prettyName: String = "input_file_name"
 
-  override protected def initInternal(): Unit = {}
+  override protected def initializeInternal(partitionIndex: Int): Unit = {}
 
   override protected def evalInternal(input: InternalRow): UTF8String = {
     InputFileNameHolder.getInputFileName()

http://git-wip-us.apache.org/repos/asf/spark/blob/02f20310/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala
index 5b4922e..72b8dcc 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala
@@ -50,9 +50,9 @@ case class MonotonicallyIncreasingID() extends LeafExpression with Nondeterminis
 
   @transient private[this] var partitionMask: Long = _
 
-  override protected def initInternal(): Unit = {
+  override protected def initializeInternal(partitionIndex: Int): Unit = {
     count = 0L
-    partitionMask = TaskContext.getPartitionId().toLong << 33
+    partitionMask = partitionIndex.toLong << 33
   }
 
   override def nullable: Boolean = false
@@ -68,9 +68,10 @@ case class MonotonicallyIncreasingID() extends LeafExpression with Nondeterminis
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     val countTerm = ctx.freshName("count")
     val partitionMaskTerm = ctx.freshName("partitionMask")
-    ctx.addMutableState(ctx.JAVA_LONG, countTerm, s"$countTerm = 0L;")
-    ctx.addMutableState(ctx.JAVA_LONG, partitionMaskTerm,
-      s"$partitionMaskTerm = ((long) org.apache.spark.TaskContext.getPartitionId()) << 33;")
+    ctx.addMutableState(ctx.JAVA_LONG, countTerm, "")
+    ctx.addMutableState(ctx.JAVA_LONG, partitionMaskTerm, "")
+    ctx.addPartitionInitializationStatement(s"$countTerm = 0L;")
+    ctx.addPartitionInitializationStatement(s"$partitionMaskTerm = ((long) partitionIndex) << 33;")
 
     ev.copy(code = s"""
       final ${ctx.javaType(dataType)} ${ev.value} = $partitionMaskTerm + $countTerm;

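The generated code above preserves the existing ID layout; a worked example of the arithmetic (values chosen for illustration):

```scala
val partitionIndex = 2                           // say, the third partition
var count = 0L                                   // per-partition record counter
val partitionMask = partitionIndex.toLong << 33  // partition index in the upper bits

val firstId = partitionMask + count              // 17179869184
count += 1
val secondId = partitionMask + count             // 17179869185
// IDs are unique and increasing within a partition; the lower 33 bits carry the count.
```
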
http://git-wip-us.apache.org/repos/asf/spark/blob/02f20310/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
index 03e054d..476e37e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
@@ -23,6 +23,7 @@ import org.apache.spark.sql.types.{DataType, StructType}
 
 /**
  * A [[Projection]] that is calculated by calling the `eval` of each of the specified expressions.
+ *
  * @param expressions a sequence of expressions that determine the value of each column of the
  *                    output row.
  */
@@ -30,10 +31,12 @@ class InterpretedProjection(expressions: Seq[Expression]) extends Projection {
   def this(expressions: Seq[Expression], inputSchema: Seq[Attribute]) =
     this(expressions.map(BindReferences.bindReference(_, inputSchema)))
 
-  expressions.foreach(_.foreach {
-    case n: Nondeterministic => n.setInitialValues()
-    case _ =>
-  })
+  override def initialize(partitionIndex: Int): Unit = {
+    expressions.foreach(_.foreach {
+      case n: Nondeterministic => n.initialize(partitionIndex)
+      case _ =>
+    })
+  }
 
   // null check is required for when Kryo invokes the no-arg constructor.
   protected val exprArray = if (expressions != null) expressions.toArray else null
@@ -54,6 +57,7 @@ class InterpretedProjection(expressions: Seq[Expression]) extends Projection {
 /**
  * A [[MutableProjection]] that is calculated by calling `eval` on each of the specified
  * expressions.
+ *
  * @param expressions a sequence of expressions that determine the value of each column of the
  *                    output row.
  */
@@ -63,10 +67,12 @@ case class InterpretedMutableProjection(expressions: Seq[Expression]) extends Mu
 
   private[this] val buffer = new Array[Any](expressions.size)
 
-  expressions.foreach(_.foreach {
-    case n: Nondeterministic => n.setInitialValues()
-    case _ =>
-  })
+  override def initialize(partitionIndex: Int): Unit = {
+    expressions.foreach(_.foreach {
+      case n: Nondeterministic => n.initialize(partitionIndex)
+      case _ =>
+    })
+  }
 
   private[this] val exprArray = expressions.toArray
   private[this] var mutableRow: InternalRow = new GenericInternalRow(exprArray.length)

http://git-wip-us.apache.org/repos/asf/spark/blob/02f20310/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala
index 1f675d5..6bef473 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala
@@ -17,16 +17,15 @@
 
 package org.apache.spark.sql.catalyst.expressions
 
-import org.apache.spark.TaskContext
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
 import org.apache.spark.sql.types.{DataType, IntegerType}
 
 /**
- * Expression that returns the current partition id of the Spark task.
+ * Expression that returns the current partition id.
  */
 @ExpressionDescription(
-  usage = "_FUNC_() - Returns the current partition id of the Spark task",
+  usage = "_FUNC_() - Returns the current partition id",
   extended = "> SELECT _FUNC_();\n 0")
 case class SparkPartitionID() extends LeafExpression with Nondeterministic {
 
@@ -38,16 +37,16 @@ case class SparkPartitionID() extends LeafExpression with Nondeterministic {
 
   override val prettyName = "SPARK_PARTITION_ID"
 
-  override protected def initInternal(): Unit = {
-    partitionId = TaskContext.getPartitionId()
+  override protected def initializeInternal(partitionIndex: Int): Unit = {
+    partitionId = partitionIndex
   }
 
   override protected def evalInternal(input: InternalRow): Int = partitionId
 
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     val idTerm = ctx.freshName("partitionId")
-    ctx.addMutableState(ctx.JAVA_INT, idTerm,
-      s"$idTerm = org.apache.spark.TaskContext.getPartitionId();")
+    ctx.addMutableState(ctx.JAVA_INT, idTerm, "")
+    ctx.addPartitionInitializationStatement(s"$idTerm = partitionIndex;")
     ev.copy(code = s"final ${ctx.javaType(dataType)} ${ev.value} = $idTerm;", isNull = "false")
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/02f20310/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 6cab50a..9c3c6d3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -185,6 +185,20 @@ class CodegenContext {
   }
 
   /**
+   * Code statements to initialize states that depend on the partition index.
+   * An integer `partitionIndex` will be made available within the scope.
+   */
+  val partitionInitializationStatements: mutable.ArrayBuffer[String] = mutable.ArrayBuffer.empty
+
+  def addPartitionInitializationStatement(statement: String): Unit = {
+    partitionInitializationStatements += statement
+  }
+
+  def initPartition(): String = {
+    partitionInitializationStatements.mkString("\n")
+  }
+
+  /**
    * Holding all the functions those will be added into generated class.
    */
   val addedFunctions: mutable.Map[String, String] =

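For example, how an expression is expected to use these hooks (mirroring the `SparkPartitionID` change above):

```scala
// Declare mutable state with no eager initializer, then defer the
// partition-dependent assignment to the generated initialize method.
val idTerm = ctx.freshName("partitionId")
ctx.addMutableState(ctx.JAVA_INT, idTerm, "")
ctx.addPartitionInitializationStatement(s"$idTerm = partitionIndex;")
// ctx.initPartition() later splices these statements into the body of
// `public void initialize(int partitionIndex)` in the generated class.
```
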
http://git-wip-us.apache.org/repos/asf/spark/blob/02f20310/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala
index 6a5a3e7..0322d1d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala
@@ -25,15 +25,23 @@ import org.apache.spark.sql.catalyst.expressions.{Expression, LeafExpression, No
 trait CodegenFallback extends Expression {
 
   protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
-    foreach {
-      case n: Nondeterministic => n.setInitialValues()
-      case _ =>
-    }
-
     // LeafNode does not need `input`
     val input = if (this.isInstanceOf[LeafExpression]) "null" else ctx.INPUT_ROW
     val idx = ctx.references.length
     ctx.references += this
+    var childIndex = idx
+    this.foreach {
+      case n: Nondeterministic =>
+        // This might add the current expression twice, but it won't hurt.
+        ctx.references += n
+        childIndex += 1
+        ctx.addPartitionInitializationStatement(
+          s"""
+             |((Nondeterministic) references[$childIndex])
+             |  .initialize(partitionIndex);
+          """.stripMargin)
+      case _ =>
+    }
     val objectTerm = ctx.freshName("obj")
     val placeHolder = ctx.registerComment(this.toString)
     if (nullable) {

http://git-wip-us.apache.org/repos/asf/spark/blob/02f20310/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
index 5c4b56b..4d73244 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
@@ -111,6 +111,10 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP
           ${ctx.initMutableStates()}
         }
 
+        public void initialize(int partitionIndex) {
+          ${ctx.initPartition()}
+        }
+
         ${ctx.declareAddedFunctions()}
 
         public ${classOf[BaseMutableProjection].getName} target(InternalRow row) {

http://git-wip-us.apache.org/repos/asf/spark/blob/02f20310/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala
index 39aa7b1..dcd1ed9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala
@@ -25,19 +25,26 @@ import org.apache.spark.sql.catalyst.expressions._
  */
 abstract class Predicate {
   def eval(r: InternalRow): Boolean
+
+  /**
+   * Initializes internal states given the current partition index.
+   * This is used by nondeterministic expressions to set initial states.
+   * The default implementation does nothing.
+   */
+  def initialize(partitionIndex: Int): Unit = {}
 }
 
 /**
  * Generates bytecode that evaluates a boolean [[Expression]] on a given input [[InternalRow]].
  */
-object GeneratePredicate extends CodeGenerator[Expression, (InternalRow) => Boolean] {
+object GeneratePredicate extends CodeGenerator[Expression, Predicate] {
 
   protected def canonicalize(in: Expression): Expression = ExpressionCanonicalizer.execute(in)
 
   protected def bind(in: Expression, inputSchema: Seq[Attribute]): Expression =
     BindReferences.bindReference(in, inputSchema)
 
-  protected def create(predicate: Expression): ((InternalRow) => Boolean) = {
+  protected def create(predicate: Expression): Predicate = {
     val ctx = newCodeGenContext()
     val eval = predicate.genCode(ctx)
 
@@ -55,6 +62,10 @@ object GeneratePredicate extends CodeGenerator[Expression, (InternalRow) => Bool
           ${ctx.initMutableStates()}
         }
 
+        public void initialize(int partitionIndex) {
+          ${ctx.initPartition()}
+        }
+
         ${ctx.declareAddedFunctions()}
 
         public boolean eval(InternalRow ${ctx.INPUT_ROW}) {
@@ -67,7 +78,6 @@ object GeneratePredicate extends CodeGenerator[Expression, (InternalRow) => Bool
       new CodeAndComment(codeBody, ctx.getPlaceHolderToComments()))
     logDebug(s"Generated predicate '$predicate':\n${CodeFormatter.format(code)}")
 
-    val p = CodeGenerator.compile(code).generate(ctx.references.toArray).asInstanceOf[Predicate]
-    (r: InternalRow) => p.eval(r)
+    CodeGenerator.compile(code).generate(ctx.references.toArray).asInstanceOf[Predicate]
   }
 }

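The resulting call pattern at the use sites (see `FilterExec` and the join operators further down):

```scala
val predicate = newPredicate(condition, child.output)  // now a Predicate, not a function
predicate.initialize(index)                            // set partition-dependent state
iter.filter { row => predicate.eval(row) }             // eval(row) replaces predicate(row)
```
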
http://git-wip-us.apache.org/repos/asf/spark/blob/02f20310/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala
index 2773e1a..b1cb6ed 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala
@@ -173,6 +173,10 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
           ${ctx.initMutableStates()}
         }
 
+        public void initialize(int partitionIndex) {
+          ${ctx.initPartition()}
+        }
+
         ${ctx.declareAddedFunctions()}
 
         public java.lang.Object apply(java.lang.Object _i) {

http://git-wip-us.apache.org/repos/asf/spark/blob/02f20310/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
index 7cc4537..7e4c908 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
@@ -380,6 +380,10 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
           ${ctx.initMutableStates()}
         }
 
+        public void initialize(int partitionIndex) {
+          ${ctx.initPartition()}
+        }
+
         ${ctx.declareAddedFunctions()}
 
         // Scala.Function1 need this

http://git-wip-us.apache.org/repos/asf/spark/blob/02f20310/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala
index 1510a47..1b00c9e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala
@@ -64,7 +64,15 @@ package object expressions  {
    * column of the new row. If the schema of the input row is specified, then the given expression
    * will be bound to that schema.
    */
-  abstract class Projection extends (InternalRow => InternalRow)
+  abstract class Projection extends (InternalRow => InternalRow) {
+
+    /**
+     * Initializes internal states given the current partition index.
+     * This is used by nondeterministic expressions to set initial states.
+     * The default implementation does nothing.
+     */
+    def initialize(partitionIndex: Int): Unit = {}
+  }
 
   /**
    * Converts a [[InternalRow]] to another Row given a sequence of expression that define each

http://git-wip-us.apache.org/repos/asf/spark/blob/02f20310/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index 9394e39..c941a57 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -31,10 +31,6 @@ object InterpretedPredicate {
     create(BindReferences.bindReference(expression, inputSchema))
 
   def create(expression: Expression): (InternalRow => Boolean) = {
-    expression.foreach {
-      case n: Nondeterministic => n.setInitialValues()
-      case _ =>
-    }
     (r: InternalRow) => expression.eval(r).asInstanceOf[Boolean]
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/02f20310/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala
index ca20076..e09029f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala
@@ -42,8 +42,8 @@ abstract class RDG extends LeafExpression with Nondeterministic {
    */
   @transient protected var rng: XORShiftRandom = _
 
-  override protected def initInternal(): Unit = {
-    rng = new XORShiftRandom(seed + TaskContext.getPartitionId)
+  override protected def initializeInternal(partitionIndex: Int): Unit = {
+    rng = new XORShiftRandom(seed + partitionIndex)
   }
 
   override def nullable: Boolean = false
@@ -70,8 +70,9 @@ case class Rand(seed: Long) extends RDG {
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     val rngTerm = ctx.freshName("rng")
     val className = classOf[XORShiftRandom].getName
-    ctx.addMutableState(className, rngTerm,
-      s"$rngTerm = new $className(${seed}L + org.apache.spark.TaskContext.getPartitionId());")
+    ctx.addMutableState(className, rngTerm, "")
+    ctx.addPartitionInitializationStatement(
+      s"$rngTerm = new $className(${seed}L + partitionIndex);")
     ev.copy(code = s"""
       final ${ctx.javaType(dataType)} ${ev.value} = $rngTerm.nextDouble();""", isNull = "false")
   }
@@ -93,8 +94,9 @@ case class Randn(seed: Long) extends RDG {
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     val rngTerm = ctx.freshName("rng")
     val className = classOf[XORShiftRandom].getName
-    ctx.addMutableState(className, rngTerm,
-      s"$rngTerm = new $className(${seed}L + org.apache.spark.TaskContext.getPartitionId());")
+    ctx.addMutableState(className, rngTerm, "")
+    ctx.addPartitionInitializationStatement(
+      s"$rngTerm = new $className(${seed}L + partitionIndex);")
     ev.copy(code = s"""
       final ${ctx.javaType(dataType)} ${ev.value} = $rngTerm.nextGaussian();""", isNull = "false")
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/02f20310/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index e5e2cd7..b6ad5db 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -1060,6 +1060,7 @@ object ConvertToLocalRelation extends Rule[LogicalPlan] {
     case Project(projectList, LocalRelation(output, data))
         if !projectList.exists(hasUnevaluableExpr) =>
       val projection = new InterpretedProjection(projectList, output)
+      projection.initialize(0)
       LocalRelation(projectList.map(_.toAttribute), data.map(projection))
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/02f20310/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
index f0c149c..9ceb709 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
@@ -75,7 +75,7 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks {
 
   protected def evaluate(expression: Expression, inputRow: InternalRow = EmptyRow): Any = {
     expression.foreach {
-      case n: Nondeterministic => n.setInitialValues()
+      case n: Nondeterministic => n.initialize(0)
       case _ =>
     }
     expression.eval(inputRow)
@@ -121,6 +121,7 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks {
     val plan = generateProject(
       GenerateMutableProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil),
       expression)
+    plan.initialize(0)
 
     val actual = plan(inputRow).get(0, expression.dataType)
     if (!checkResult(actual, expected)) {
@@ -182,12 +183,14 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks {
     var plan = generateProject(
       GenerateMutableProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil),
       expression)
+    plan.initialize(0)
     var actual = plan(inputRow).get(0, expression.dataType)
     assert(checkResult(actual, expected))
 
     plan = generateProject(
       GenerateUnsafeProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil),
       expression)
+    plan.initialize(0)
     actual = FromUnsafeProjection(expression.dataType :: Nil)(
       plan(inputRow)).get(0, expression.dataType)
     assert(checkResult(actual, expected))

http://git-wip-us.apache.org/repos/asf/spark/blob/02f20310/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenExpressionCachingSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenExpressionCachingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenExpressionCachingSuite.scala
index 06dc3bd..fe5cb8e 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenExpressionCachingSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenExpressionCachingSuite.scala
@@ -31,19 +31,22 @@ class CodegenExpressionCachingSuite extends SparkFunSuite {
     // Use an Add to wrap two of them together in case we only initialize the top level expressions.
     val expr = And(NondeterministicExpression(), NondeterministicExpression())
     val instance = UnsafeProjection.create(Seq(expr))
+    instance.initialize(0)
     assert(instance.apply(null).getBoolean(0) === false)
   }
 
   test("GenerateMutableProjection should initialize expressions") {
     val expr = And(NondeterministicExpression(), NondeterministicExpression())
     val instance = GenerateMutableProjection.generate(Seq(expr))
+    instance.initialize(0)
     assert(instance.apply(null).getBoolean(0) === false)
   }
 
   test("GeneratePredicate should initialize expressions") {
     val expr = And(NondeterministicExpression(), NondeterministicExpression())
     val instance = GeneratePredicate.generate(expr)
-    assert(instance.apply(null) === false)
+    instance.initialize(0)
+    assert(instance.eval(null) === false)
   }
 
   test("GenerateUnsafeProjection should not share expression instances") {
@@ -73,13 +76,13 @@ class CodegenExpressionCachingSuite extends SparkFunSuite {
   test("GeneratePredicate should not share expression instances") {
     val expr1 = MutableExpression()
     val instance1 = GeneratePredicate.generate(expr1)
-    assert(instance1.apply(null) === false)
+    assert(instance1.eval(null) === false)
 
     val expr2 = MutableExpression()
     expr2.mutableState = true
     val instance2 = GeneratePredicate.generate(expr2)
-    assert(instance1.apply(null) === false)
-    assert(instance2.apply(null) === true)
+    assert(instance1.eval(null) === false)
+    assert(instance2.eval(null) === true)
   }
 
 }
@@ -89,7 +92,7 @@ class CodegenExpressionCachingSuite extends SparkFunSuite {
  */
 case class NondeterministicExpression()
   extends LeafExpression with Nondeterministic with CodegenFallback {
-  override protected def initInternal(): Unit = { }
+  override protected def initializeInternal(partitionIndex: Int): Unit = {}
   override protected def evalInternal(input: InternalRow): Any = false
   override def nullable: Boolean = false
   override def dataType: DataType = BooleanType

http://git-wip-us.apache.org/repos/asf/spark/blob/02f20310/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
index fdd1fa3..e485b52 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
@@ -71,8 +71,9 @@ case class RowDataSourceScanExec(
     val unsafeRow = if (outputUnsafeRows) {
       rdd
     } else {
-      rdd.mapPartitionsInternal { iter =>
+      rdd.mapPartitionsWithIndexInternal { (index, iter) =>
         val proj = UnsafeProjection.create(schema)
+        proj.initialize(index)
         iter.map(proj)
       }
     }
@@ -284,8 +285,9 @@ case class FileSourceScanExec(
       val unsafeRows = {
         val scan = inputRDD
         if (needsUnsafeRowConversion) {
-          scan.mapPartitionsInternal { iter =>
+          scan.mapPartitionsWithIndexInternal { (index, iter) =>
             val proj = UnsafeProjection.create(schema)
+            proj.initialize(index)
             iter.map(proj)
           }
         } else {

http://git-wip-us.apache.org/repos/asf/spark/blob/02f20310/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
index 455fb5b..aab087c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
@@ -190,8 +190,9 @@ case class RDDScanExec(
 
   protected override def doExecute(): RDD[InternalRow] = {
     val numOutputRows = longMetric("numOutputRows")
-    rdd.mapPartitionsInternal { iter =>
+    rdd.mapPartitionsWithIndexInternal { (index, iter) =>
       val proj = UnsafeProjection.create(schema)
+      proj.initialize(index)
       iter.map { r =>
         numOutputRows += 1
         proj(r)

http://git-wip-us.apache.org/repos/asf/spark/blob/02f20310/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala
index 2663129..19fbf0c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala
@@ -94,8 +94,9 @@ case class GenerateExec(
     }
 
     val numOutputRows = longMetric("numOutputRows")
-    rows.mapPartitionsInternal { iter =>
+    rows.mapPartitionsWithIndexInternal { (index, iter) =>
       val proj = UnsafeProjection.create(output, output)
+      proj.initialize(index)
       iter.map { r =>
         numOutputRows += 1
         proj(r)

http://git-wip-us.apache.org/repos/asf/spark/blob/02f20310/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
index 24d0cff..cadab37 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
@@ -29,7 +29,7 @@ import org.apache.spark.rdd.{RDD, RDDOperationScope}
 import org.apache.spark.sql.{Row, SparkSession}
 import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.codegen._
+import org.apache.spark.sql.catalyst.expressions.codegen.{Predicate => GenPredicate, _}
 import org.apache.spark.sql.catalyst.plans.QueryPlan
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.execution.metric.SQLMetric
@@ -354,7 +354,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
   }
 
   protected def newPredicate(
-      expression: Expression, inputSchema: Seq[Attribute]): (InternalRow) => Boolean = {
+      expression: Expression, inputSchema: Seq[Attribute]): GenPredicate = {
     GeneratePredicate.generate(expression, inputSchema)
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/02f20310/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
index 6303483..516b9d5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
@@ -331,6 +331,7 @@ case class WholeStageCodegenExec(child: SparkPlan) extends UnaryExecNode with Co
           partitionIndex = index;
           this.inputs = inputs;
           ${ctx.initMutableStates()}
+          ${ctx.initPartition()}
         }
 
         ${ctx.declareAddedFunctions()}
@@ -383,10 +384,13 @@ case class WholeStageCodegenExec(child: SparkPlan) extends UnaryExecNode with Co
     } else {
       // Right now, we support up to two input RDDs.
       rdds.head.zipPartitions(rdds(1)) { (leftIter, rightIter) =>
-        val partitionIndex = TaskContext.getPartitionId()
+        Iterator((leftIter, rightIter))
+        // a small hack to obtain the correct partition index
+      }.mapPartitionsWithIndex { (index, zippedIter) =>
+        val (leftIter, rightIter) = zippedIter.next()
         val clazz = CodeGenerator.compile(cleanedSource)
         val buffer = clazz.generate(references).asInstanceOf[BufferedRowIterator]
-        buffer.init(partitionIndex, Array(leftIter, rightIter))
+        buffer.init(index, Array(leftIter, rightIter))
         new Iterator[InternalRow] {
           override def hasNext: Boolean = {
             val v = buffer.hasNext

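Spelling out the "small hack" above: `zipPartitions` exposes no partition index, so each pair of iterators is wrapped into a one-element partition and `mapPartitionsWithIndex` is layered on top to recover the index before initializing the generated iterator:

```scala
rdds.head.zipPartitions(rdds(1)) { (leftIter, rightIter) =>
  Iterator((leftIter, rightIter))               // exactly one element per partition
}.mapPartitionsWithIndex { (index, zippedIter) =>
  val (leftIter, rightIter) = zippedIter.next() // unwrap; index is the true partition index
  buffer.init(index, Array(leftIter, rightIter))
  // ... then drive `buffer` exactly as in the single-RDD path
}
```
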
http://git-wip-us.apache.org/repos/asf/spark/blob/02f20310/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
index a5291e0..32133f5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
@@ -70,9 +70,10 @@ case class ProjectExec(projectList: Seq[NamedExpression], child: SparkPlan)
   }
 
   protected override def doExecute(): RDD[InternalRow] = {
-    child.execute().mapPartitionsInternal { iter =>
+    child.execute().mapPartitionsWithIndexInternal { (index, iter) =>
       val project = UnsafeProjection.create(projectList, child.output,
         subexpressionEliminationEnabled)
+      project.initialize(index)
       iter.map(project)
     }
   }
@@ -205,10 +206,11 @@ case class FilterExec(condition: Expression, child: SparkPlan)
 
   protected override def doExecute(): RDD[InternalRow] = {
     val numOutputRows = longMetric("numOutputRows")
-    child.execute().mapPartitionsInternal { iter =>
+    child.execute().mapPartitionsWithIndexInternal { (index, iter) =>
       val predicate = newPredicate(condition, child.output)
+      predicate.initialize(0)
       iter.filter { row =>
-        val r = predicate(row)
+        val r = predicate.eval(row)
         if (r) numOutputRows += 1
         r
       }

http://git-wip-us.apache.org/repos/asf/spark/blob/02f20310/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
index b87016d..9028caa 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
@@ -132,10 +132,11 @@ case class InMemoryTableScanExec(
     val relOutput: AttributeSeq = relation.output
     val buffers = relation.cachedColumnBuffers
 
-    buffers.mapPartitionsInternal { cachedBatchIterator =>
+    buffers.mapPartitionsWithIndexInternal { (index, cachedBatchIterator) =>
       val partitionFilter = newPredicate(
         partitionFilters.reduceOption(And).getOrElse(Literal(true)),
         schema)
+      partitionFilter.initialize(index)
 
       // Find the ordinals and data types of the requested columns.
       val (requestedColumnIndices, requestedColumnDataTypes) =
@@ -147,7 +148,7 @@ case class InMemoryTableScanExec(
       val cachedBatchesToScan =
         if (inMemoryPartitionPruningEnabled) {
           cachedBatchIterator.filter { cachedBatch =>
-            if (!partitionFilter(cachedBatch.stats)) {
+            if (!partitionFilter.eval(cachedBatch.stats)) {
               def statsString: String = schemaIndex.map {
                 case (a, i) =>
                   val value = cachedBatch.stats.get(i, a.dataType)

http://git-wip-us.apache.org/repos/asf/spark/blob/02f20310/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala
index bfe7e3d..f526a19 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala
@@ -52,7 +52,7 @@ case class BroadcastNestedLoopJoinExec(
       UnspecifiedDistribution :: BroadcastDistribution(IdentityBroadcastMode) :: Nil
   }
 
-  private[this] def genResultProjection: InternalRow => InternalRow = joinType match {
+  private[this] def genResultProjection: UnsafeProjection = joinType match {
     case LeftExistence(j) =>
       UnsafeProjection.create(output, output)
     case other =>
@@ -84,7 +84,7 @@ case class BroadcastNestedLoopJoinExec(
 
   @transient private lazy val boundCondition = {
     if (condition.isDefined) {
-      newPredicate(condition.get, streamed.output ++ broadcast.output)
+      newPredicate(condition.get, streamed.output ++ broadcast.output).eval _
     } else {
       (r: InternalRow) => true
     }
@@ -366,8 +366,9 @@ case class BroadcastNestedLoopJoinExec(
     }
 
     val numOutputRows = longMetric("numOutputRows")
-    resultRdd.mapPartitionsInternal { iter =>
+    resultRdd.mapPartitionsWithIndexInternal { (index, iter) =>
       val resultProj = genResultProjection
+      resultProj.initialize(index)
       iter.map { r =>
         numOutputRows += 1
         resultProj(r)

http://git-wip-us.apache.org/repos/asf/spark/blob/02f20310/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala
index 15dc9b4..8341fe2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala
@@ -98,15 +98,15 @@ case class CartesianProductExec(
     val rightResults = right.execute().asInstanceOf[RDD[UnsafeRow]]
 
     val pair = new UnsafeCartesianRDD(leftResults, rightResults, right.output.size)
-    pair.mapPartitionsInternal { iter =>
+    pair.mapPartitionsWithIndexInternal { (index, iter) =>
       val joiner = GenerateUnsafeRowJoiner.create(left.schema, right.schema)
       val filtered = if (condition.isDefined) {
-        val boundCondition: (InternalRow) => Boolean =
-          newPredicate(condition.get, left.output ++ right.output)
+        val boundCondition = newPredicate(condition.get, left.output ++ right.output)
+        boundCondition.initialize(index)
         val joined = new JoinedRow
 
         iter.filter { r =>
-          boundCondition(joined(r._1, r._2))
+          boundCondition.eval(joined(r._1, r._2))
         }
       } else {
         iter

http://git-wip-us.apache.org/repos/asf/spark/blob/02f20310/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
index 05c5e2f..1aef5f6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
@@ -81,7 +81,7 @@ trait HashJoin {
     UnsafeProjection.create(streamedKeys)
 
   @transient private[this] lazy val boundCondition = if (condition.isDefined) {
-    newPredicate(condition.get, streamedPlan.output ++ buildPlan.output)
+    newPredicate(condition.get, streamedPlan.output ++ buildPlan.output).eval _
   } else {
     (r: InternalRow) => true
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/02f20310/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
index ecf7cf2..ca9c0ed 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
@@ -101,7 +101,7 @@ case class SortMergeJoinExec(
     left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) =>
       val boundCondition: (InternalRow) => Boolean = {
         condition.map { cond =>
-          newPredicate(cond, left.output ++ right.output)
+          newPredicate(cond, left.output ++ right.output).eval _
         }.getOrElse {
           (r: InternalRow) => true
         }

http://git-wip-us.apache.org/repos/asf/spark/blob/02f20310/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
index 9df56bb..fde3b2a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
@@ -87,8 +87,9 @@ case class DeserializeToObjectExec(
   }
 
   override protected def doExecute(): RDD[InternalRow] = {
-    child.execute().mapPartitionsInternal { iter =>
+    child.execute().mapPartitionsWithIndexInternal { (index, iter) =>
       val projection = GenerateSafeProjection.generate(deserializer :: Nil, child.output)
+      projection.initialize(index)
       iter.map(projection)
     }
   }
@@ -124,8 +125,9 @@ case class SerializeFromObjectExec(
   }
 
   override protected def doExecute(): RDD[InternalRow] = {
-    child.execute().mapPartitionsInternal { iter =>
+    child.execute().mapPartitionsWithIndexInternal { (index, iter) =>
       val projection = UnsafeProjection.create(serializer)
+      projection.initialize(index)
       iter.map(projection)
     }
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/02f20310/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index 586a0ff..0e9a2c6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -19,7 +19,13 @@ package org.apache.spark.sql
 
 import java.nio.charset.StandardCharsets
 
+import scala.util.Random
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
 import org.apache.spark.sql.functions._
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types._
 
@@ -406,4 +412,50 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
       Seq(Row(true), Row(true))
     )
   }
+
+  private def assertValuesDoNotChangeAfterCoalesceOrUnion(v: Column): Unit = {
+    import DataFrameFunctionsSuite.CodegenFallbackExpr
+    for ((codegenFallback, wholeStage) <- Seq((true, false), (false, false), (false, true))) {
+      val c = if (codegenFallback) {
+        Column(CodegenFallbackExpr(v.expr))
+      } else {
+        v
+      }
+      withSQLConf(
+        (SQLConf.WHOLESTAGE_FALLBACK.key, codegenFallback.toString),
+        (SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, wholeStage.toString)) {
+        val df = spark.range(0, 4, 1, 4).withColumn("c", c)
+        val rows = df.collect()
+        val rowsAfterCoalesce = df.coalesce(2).collect()
+        assert(rows === rowsAfterCoalesce, "Values changed after coalesce when " +
+          s"codegenFallback=$codegenFallback and wholeStage=$wholeStage.")
+
+        val df1 = spark.range(0, 2, 1, 2).withColumn("c", c)
+        val rows1 = df1.collect()
+        val df2 = spark.range(2, 4, 1, 2).withColumn("c", c)
+        val rows2 = df2.collect()
+        val rowsAfterUnion = df1.union(df2).collect()
+        assert(rowsAfterUnion === rows1 ++ rows2, "Values changed after union when " +
+          s"codegenFallback=$codegenFallback and wholeStage=$wholeStage.")
+      }
+    }
+  }
+
+  test("SPARK-14393: values generated by non-deterministic functions shouldn't change after " +
+    "coalesce or union") {
+    Seq(
+      monotonically_increasing_id(), spark_partition_id(),
+      rand(Random.nextLong()), randn(Random.nextLong())
+    ).foreach(assertValuesDoNotChangeAfterCoalesceOrUnion(_))
+  }
+}
+
+object DataFrameFunctionsSuite {
+  case class CodegenFallbackExpr(child: Expression) extends Expression with CodegenFallback {
+    override def children: Seq[Expression] = Seq(child)
+    override def nullable: Boolean = child.nullable
+    override def dataType: DataType = child.dataType
+    override lazy val resolved = true
+    override def eval(input: InternalRow): Any = child.eval(input)
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/02f20310/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala
index 231f204..c80695b 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala
@@ -154,8 +154,9 @@ case class HiveTableScanExec(
     val numOutputRows = longMetric("numOutputRows")
     // Avoid to serialize MetastoreRelation because schema is lazy. (see SPARK-15649)
     val outputSchema = schema
-    rdd.mapPartitionsInternal { iter =>
+    rdd.mapPartitionsWithIndexInternal { (index, iter) =>
       val proj = UnsafeProjection.create(outputSchema)
+      proj.initialize(index)
       iter.map { r =>
         numOutputRows += 1
         proj(r)

