From 0093257ea94d3a197ca061b54c04685d7c1f616a Mon Sep 17 00:00:00 2001
From: Xiangrui Meng <meng@databricks.com>
Date: Wed, 2 Nov 2016 11:41:49 -0700
Subject: [PATCH] [SPARK-14393][SQL] values generated by non-deterministic
 functions shouldn't change after coalesce or union

## What changes were proposed in this pull request?

When a user appended a column using a "nondeterministic" function to a DataFrame, e.g., `rand`, `randn`, and `monotonically_increasing_id`, the expected semantic is the following:
- The value in each row should remain unchanged, as if we materialize the column immediately, regardless of later DataFrame operations.

However, since we use `TaskContext.getPartitionId` to get the partition index from the current thread, the values from nondeterministic columns might change if we call `union` or `coalesce` after. `TaskContext.getPartitionId` returns the partition index of the current Spark task, which might not be the corresponding partition index of the DataFrame where we defined the column.

See the unit tests below or JIRA for examples.

This PR uses the partition index from `RDD.mapPartitionWithIndex` instead of `TaskContext` and fixes the partition initialization logic in whole-stage codegen, normal codegen, and codegen fallback. `initializeStatesForPartition(partitionIndex: Int)` was added to `Projection`, `Nondeterministic`, and `Predicate` (codegen) and initialized right after object creation in `mapPartitionWithIndex`. `newPredicate` now returns a `Predicate` instance rather than a function for proper initialization.
## How was this patch tested?

Unit tests. (Actually I'm not very confident that this PR fixed all issues without introducing new ones ...)

cc: rxin davies

Author: Xiangrui Meng <meng@databricks.com>

Closes #15567 from mengxr/SPARK-14393.

(cherry picked from commit 02f203107b8eda1f1576e36c4f12b0e3bc5e910e)
Signed-off-by: Reynold Xin <rxin@databricks.com>
---
 .../main/scala/org/apache/spark/rdd/RDD.scala | 16 +++++-
 .../sql/catalyst/expressions/Expression.scala | 19 +++++--
 .../catalyst/expressions/InputFileName.scala  |  2 +-
 .../MonotonicallyIncreasingID.scala           | 11 ++--
 .../sql/catalyst/expressions/Projection.scala | 22 +++++---
 .../expressions/SparkPartitionID.scala        | 13 +++--
 .../expressions/codegen/CodeGenerator.scala   | 14 +++++
 .../expressions/codegen/CodegenFallback.scala | 18 +++++--
 .../codegen/GenerateMutableProjection.scala   |  4 ++
 .../codegen/GeneratePredicate.scala           | 18 +++++--
 .../codegen/GenerateSafeProjection.scala      |  4 ++
 .../codegen/GenerateUnsafeProjection.scala    |  4 ++
 .../sql/catalyst/expressions/package.scala    | 10 +++-
 .../sql/catalyst/expressions/predicates.scala |  4 --
 .../expressions/randomExpressions.scala       | 14 ++---
 .../sql/catalyst/optimizer/Optimizer.scala    |  1 +
 .../expressions/ExpressionEvalHelper.scala    |  5 +-
 .../CodegenExpressionCachingSuite.scala       | 13 +++--
 .../sql/execution/DataSourceScanExec.scala    |  6 ++-
 .../spark/sql/execution/ExistingRDD.scala     |  3 +-
 .../spark/sql/execution/GenerateExec.scala    |  3 +-
 .../spark/sql/execution/SparkPlan.scala       |  4 +-
 .../sql/execution/WholeStageCodegenExec.scala |  8 ++-
 .../execution/basicPhysicalOperators.scala    |  8 +--
 .../columnar/InMemoryTableScanExec.scala      |  5 +-
 .../joins/BroadcastNestedLoopJoinExec.scala   |  7 +--
 .../joins/CartesianProductExec.scala          |  8 +--
 .../spark/sql/execution/joins/HashJoin.scala  |  2 +-
 .../execution/joins/SortMergeJoinExec.scala   |  2 +-
 .../apache/spark/sql/execution/objects.scala  |  6 ++-
 .../spark/sql/DataFrameFunctionsSuite.scala   | 52 +++++++++++++++++++
 .../hive/execution/HiveTableScanExec.scala    |  3 +-
 32 files changed, 231 insertions(+), 78 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index db535de9e9..e018af35cb 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -788,14 +788,26 @@ abstract class RDD[T: ClassTag](
   }
 
   /**
-   * [performance] Spark's internal mapPartitions method which skips closure cleaning. It is a
-   * performance API to be used carefully only if we are sure that the RDD elements are
+   * [performance] Spark's internal mapPartitionsWithIndex method that skips closure cleaning.
+   * It is a performance API to be used carefully only if we are sure that the RDD elements are
    * serializable and don't require closure cleaning.
    *
    * @param preservesPartitioning indicates whether the input function preserves the partitioner,
    * which should be `false` unless this is a pair RDD and the input function doesn't modify
    * the keys.
    */
+  private[spark] def mapPartitionsWithIndexInternal[U: ClassTag](
+      f: (Int, Iterator[T]) => Iterator[U],
+      preservesPartitioning: Boolean = false): RDD[U] = withScope {
+    new MapPartitionsRDD(
+      this,
+      (context: TaskContext, index: Int, iter: Iterator[T]) => f(index, iter),
+      preservesPartitioning)
+  }
+
+  /**
+   * [performance] Spark's internal mapPartitions method that skips closure cleaning.
+   */
   private[spark] def mapPartitionsInternal[U: ClassTag](
       f: Iterator[T] => Iterator[U],
       preservesPartitioning: Boolean = false): RDD[U] = withScope {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index 9edc1ceff2..726a231fd8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -272,17 +272,28 @@ trait Nondeterministic extends Expression {
   final override def deterministic: Boolean = false
   final override def foldable: Boolean = false
 
+  @transient
   private[this] var initialized = false
 
-  final def setInitialValues(): Unit = {
-    initInternal()
+  /**
+   * Initializes internal states given the current partition index and mark this as initialized.
+   * Subclasses should override [[initializeInternal()]].
+   */
+  final def initialize(partitionIndex: Int): Unit = {
+    initializeInternal(partitionIndex)
     initialized = true
   }
 
-  protected def initInternal(): Unit
+  protected def initializeInternal(partitionIndex: Int): Unit
 
+  /**
+   * @inheritdoc
+   * Throws an exception if [[initialize()]] is not called yet.
+   * Subclasses should override [[evalInternal()]].
+   */
   final override def eval(input: InternalRow = null): Any = {
-    require(initialized, "nondeterministic expression should be initialized before evaluate")
+    require(initialized,
+      s"Nondeterministic expression ${this.getClass.getName} should be initialized before eval.")
     evalInternal(input)
   }
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala
index 96929ecf56..b6c12c5351 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala
@@ -37,7 +37,7 @@ case class InputFileName() extends LeafExpression with Nondeterministic {
 
   override def prettyName: String = "input_file_name"
 
-  override protected def initInternal(): Unit = {}
+  override protected def initializeInternal(partitionIndex: Int): Unit = {}
 
   override protected def evalInternal(input: InternalRow): UTF8String = {
     InputFileNameHolder.getInputFileName()
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala
index 5b4922e0cf..72b8dcca26 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala
@@ -50,9 +50,9 @@ case class MonotonicallyIncreasingID() extends LeafExpression with Nondeterminis
 
   @transient private[this] var partitionMask: Long = _
 
-  override protected def initInternal(): Unit = {
+  override protected def initializeInternal(partitionIndex: Int): Unit = {
     count = 0L
-    partitionMask = TaskContext.getPartitionId().toLong << 33
+    partitionMask = partitionIndex.toLong << 33
   }
 
   override def nullable: Boolean = false
@@ -68,9 +68,10 @@ case class MonotonicallyIncreasingID() extends LeafExpression with Nondeterminis
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     val countTerm = ctx.freshName("count")
     val partitionMaskTerm = ctx.freshName("partitionMask")
-    ctx.addMutableState(ctx.JAVA_LONG, countTerm, s"$countTerm = 0L;")
-    ctx.addMutableState(ctx.JAVA_LONG, partitionMaskTerm,
-      s"$partitionMaskTerm = ((long) org.apache.spark.TaskContext.getPartitionId()) << 33;")
+    ctx.addMutableState(ctx.JAVA_LONG, countTerm, "")
+    ctx.addMutableState(ctx.JAVA_LONG, partitionMaskTerm, "")
+    ctx.addPartitionInitializationStatement(s"$countTerm = 0L;")
+    ctx.addPartitionInitializationStatement(s"$partitionMaskTerm = ((long) partitionIndex) << 33;")
 
     ev.copy(code = s"""
       final ${ctx.javaType(dataType)} ${ev.value} = $partitionMaskTerm + $countTerm;
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
index 03e054d098..476e37e6a9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
@@ -23,6 +23,7 @@ import org.apache.spark.sql.types.{DataType, StructType}
 
 /**
  * A [[Projection]] that is calculated by calling the `eval` of each of the specified expressions.
+ *
  * @param expressions a sequence of expressions that determine the value of each column of the
  *                    output row.
  */
@@ -30,10 +31,12 @@ class InterpretedProjection(expressions: Seq[Expression]) extends Projection {
   def this(expressions: Seq[Expression], inputSchema: Seq[Attribute]) =
     this(expressions.map(BindReferences.bindReference(_, inputSchema)))
 
-  expressions.foreach(_.foreach {
-    case n: Nondeterministic => n.setInitialValues()
-    case _ =>
-  })
+  override def initialize(partitionIndex: Int): Unit = {
+    expressions.foreach(_.foreach {
+      case n: Nondeterministic => n.initialize(partitionIndex)
+      case _ =>
+    })
+  }
 
   // null check is required for when Kryo invokes the no-arg constructor.
   protected val exprArray = if (expressions != null) expressions.toArray else null
@@ -54,6 +57,7 @@ class InterpretedProjection(expressions: Seq[Expression]) extends Projection {
 /**
  * A [[MutableProjection]] that is calculated by calling `eval` on each of the specified
  * expressions.
+ *
  * @param expressions a sequence of expressions that determine the value of each column of the
  *                    output row.
  */
@@ -63,10 +67,12 @@ case class InterpretedMutableProjection(expressions: Seq[Expression]) extends Mu
 
   private[this] val buffer = new Array[Any](expressions.size)
 
-  expressions.foreach(_.foreach {
-    case n: Nondeterministic => n.setInitialValues()
-    case _ =>
-  })
+  override def initialize(partitionIndex: Int): Unit = {
+    expressions.foreach(_.foreach {
+      case n: Nondeterministic => n.initialize(partitionIndex)
+      case _ =>
+    })
+  }
 
   private[this] val exprArray = expressions.toArray
   private[this] var mutableRow: InternalRow = new GenericInternalRow(exprArray.length)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala
index 1f675d5b07..6bef473cac 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala
@@ -17,16 +17,15 @@
 
 package org.apache.spark.sql.catalyst.expressions
 
-import org.apache.spark.TaskContext
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
 import org.apache.spark.sql.types.{DataType, IntegerType}
 
 /**
- * Expression that returns the current partition id of the Spark task.
+ * Expression that returns the current partition id.
  */
 @ExpressionDescription(
-  usage = "_FUNC_() - Returns the current partition id of the Spark task",
+  usage = "_FUNC_() - Returns the current partition id",
   extended = "> SELECT _FUNC_();\n 0")
 case class SparkPartitionID() extends LeafExpression with Nondeterministic {
 
@@ -38,16 +37,16 @@ case class SparkPartitionID() extends LeafExpression with Nondeterministic {
 
   override val prettyName = "SPARK_PARTITION_ID"
 
-  override protected def initInternal(): Unit = {
-    partitionId = TaskContext.getPartitionId()
+  override protected def initializeInternal(partitionIndex: Int): Unit = {
+    partitionId = partitionIndex
   }
 
   override protected def evalInternal(input: InternalRow): Int = partitionId
 
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     val idTerm = ctx.freshName("partitionId")
-    ctx.addMutableState(ctx.JAVA_INT, idTerm,
-      s"$idTerm = org.apache.spark.TaskContext.getPartitionId();")
+    ctx.addMutableState(ctx.JAVA_INT, idTerm, "")
+    ctx.addPartitionInitializationStatement(s"$idTerm = partitionIndex;")
     ev.copy(code = s"final ${ctx.javaType(dataType)} ${ev.value} = $idTerm;", isNull = "false")
   }
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 6cab50ae1b..9c3c6d3b2a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -184,6 +184,20 @@ class CodegenContext {
     splitExpressions(initCodes, "init", Nil)
   }
 
+  /**
+   * Code statements to initialize states that depend on the partition index.
+   * An integer `partitionIndex` will be made available within the scope.
+   */
+  val partitionInitializationStatements: mutable.ArrayBuffer[String] = mutable.ArrayBuffer.empty
+
+  def addPartitionInitializationStatement(statement: String): Unit = {
+    partitionInitializationStatements += statement
+  }
+
+  def initPartition(): String = {
+    partitionInitializationStatements.mkString("\n")
+  }
+
   /**
    * Holding all the functions those will be added into generated class.
    */
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala
index 6a5a3e7933..0322d1dd6a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala
@@ -25,15 +25,23 @@ import org.apache.spark.sql.catalyst.expressions.{Expression, LeafExpression, No
 trait CodegenFallback extends Expression {
 
   protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
-    foreach {
-      case n: Nondeterministic => n.setInitialValues()
-      case _ =>
-    }
-
     // LeafNode does not need `input`
     val input = if (this.isInstanceOf[LeafExpression]) "null" else ctx.INPUT_ROW
     val idx = ctx.references.length
     ctx.references += this
+    var childIndex = idx
+    this.foreach {
+      case n: Nondeterministic =>
+        // This might add the current expression twice, but it won't hurt.
+        ctx.references += n
+        childIndex += 1
+        ctx.addPartitionInitializationStatement(
+          s"""
+             |((Nondeterministic) references[$childIndex])
+             |  .initialize(partitionIndex);
+          """.stripMargin)
+      case _ =>
+    }
     val objectTerm = ctx.freshName("obj")
     val placeHolder = ctx.registerComment(this.toString)
     if (nullable) {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
index 5c4b56b0b2..4d73244554 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
@@ -111,6 +111,10 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP
           ${ctx.initMutableStates()}
         }
 
+        public void initialize(int partitionIndex) {
+          ${ctx.initPartition()}
+        }
+
         ${ctx.declareAddedFunctions()}
 
         public ${classOf[BaseMutableProjection].getName} target(InternalRow row) {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala
index 39aa7b17de..dcd1ed96a2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala
@@ -25,19 +25,26 @@ import org.apache.spark.sql.catalyst.expressions._
  */
 abstract class Predicate {
   def eval(r: InternalRow): Boolean
+
+  /**
+   * Initializes internal states given the current partition index.
+   * This is used by nondeterministic expressions to set initial states.
+   * The default implementation does nothing.
+   */
+  def initialize(partitionIndex: Int): Unit = {}
 }
 
 /**
  * Generates bytecode that evaluates a boolean [[Expression]] on a given input [[InternalRow]].
  */
-object GeneratePredicate extends CodeGenerator[Expression, (InternalRow) => Boolean] {
+object GeneratePredicate extends CodeGenerator[Expression, Predicate] {
 
   protected def canonicalize(in: Expression): Expression = ExpressionCanonicalizer.execute(in)
 
   protected def bind(in: Expression, inputSchema: Seq[Attribute]): Expression =
     BindReferences.bindReference(in, inputSchema)
 
-  protected def create(predicate: Expression): ((InternalRow) => Boolean) = {
+  protected def create(predicate: Expression): Predicate = {
     val ctx = newCodeGenContext()
     val eval = predicate.genCode(ctx)
 
@@ -55,6 +62,10 @@ object GeneratePredicate extends CodeGenerator[Expression, (InternalRow) => Bool
           ${ctx.initMutableStates()}
         }
 
+        public void initialize(int partitionIndex) {
+          ${ctx.initPartition()}
+        }
+
         ${ctx.declareAddedFunctions()}
 
         public boolean eval(InternalRow ${ctx.INPUT_ROW}) {
@@ -67,7 +78,6 @@ object GeneratePredicate extends CodeGenerator[Expression, (InternalRow) => Bool
       new CodeAndComment(codeBody, ctx.getPlaceHolderToComments()))
     logDebug(s"Generated predicate '$predicate':\n${CodeFormatter.format(code)}")
 
-    val p = CodeGenerator.compile(code).generate(ctx.references.toArray).asInstanceOf[Predicate]
-    (r: InternalRow) => p.eval(r)
+    CodeGenerator.compile(code).generate(ctx.references.toArray).asInstanceOf[Predicate]
   }
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala
index 2773e1a666..b1cb6edefb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala
@@ -173,6 +173,10 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
           ${ctx.initMutableStates()}
         }
 
+        public void initialize(int partitionIndex) {
+          ${ctx.initPartition()}
+        }
+
         ${ctx.declareAddedFunctions()}
 
         public java.lang.Object apply(java.lang.Object _i) {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
index 7cc45372da..7e4c9089a2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
@@ -380,6 +380,10 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
           ${ctx.initMutableStates()}
         }
 
+        public void initialize(int partitionIndex) {
+          ${ctx.initPartition()}
+        }
+
         ${ctx.declareAddedFunctions()}
 
         // Scala.Function1 need this
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala
index 1510a47966..1b00c9e79d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala
@@ -64,7 +64,15 @@ package object expressions  {
    * column of the new row. If the schema of the input row is specified, then the given expression
    * will be bound to that schema.
    */
-  abstract class Projection extends (InternalRow => InternalRow)
+  abstract class Projection extends (InternalRow => InternalRow) {
+
+    /**
+     * Initializes internal states given the current partition index.
+     * This is used by nondeterministic expressions to set initial states.
+     * The default implementation does nothing.
+     */
+    def initialize(partitionIndex: Int): Unit = {}
+  }
 
   /**
    * Converts a [[InternalRow]] to another Row given a sequence of expression that define each
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index 9394e39aad..c941a576d0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -31,10 +31,6 @@ object InterpretedPredicate {
     create(BindReferences.bindReference(expression, inputSchema))
 
   def create(expression: Expression): (InternalRow => Boolean) = {
-    expression.foreach {
-      case n: Nondeterministic => n.setInitialValues()
-      case _ =>
-    }
     (r: InternalRow) => expression.eval(r).asInstanceOf[Boolean]
   }
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala
index ca200768b2..e09029f5aa 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala
@@ -42,8 +42,8 @@ abstract class RDG extends LeafExpression with Nondeterministic {
    */
   @transient protected var rng: XORShiftRandom = _
 
-  override protected def initInternal(): Unit = {
-    rng = new XORShiftRandom(seed + TaskContext.getPartitionId)
+  override protected def initializeInternal(partitionIndex: Int): Unit = {
+    rng = new XORShiftRandom(seed + partitionIndex)
   }
 
   override def nullable: Boolean = false
@@ -70,8 +70,9 @@ case class Rand(seed: Long) extends RDG {
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     val rngTerm = ctx.freshName("rng")
     val className = classOf[XORShiftRandom].getName
-    ctx.addMutableState(className, rngTerm,
-      s"$rngTerm = new $className(${seed}L + org.apache.spark.TaskContext.getPartitionId());")
+    ctx.addMutableState(className, rngTerm, "")
+    ctx.addPartitionInitializationStatement(
+      s"$rngTerm = new $className(${seed}L + partitionIndex);")
     ev.copy(code = s"""
       final ${ctx.javaType(dataType)} ${ev.value} = $rngTerm.nextDouble();""", isNull = "false")
   }
@@ -93,8 +94,9 @@ case class Randn(seed: Long) extends RDG {
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     val rngTerm = ctx.freshName("rng")
     val className = classOf[XORShiftRandom].getName
-    ctx.addMutableState(className, rngTerm,
-      s"$rngTerm = new $className(${seed}L + org.apache.spark.TaskContext.getPartitionId());")
+    ctx.addMutableState(className, rngTerm, "")
+    ctx.addPartitionInitializationStatement(
+      s"$rngTerm = new $className(${seed}L + partitionIndex);")
     ev.copy(code = s"""
       final ${ctx.javaType(dataType)} ${ev.value} = $rngTerm.nextGaussian();""", isNull = "false")
   }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index e5e2cd7d27..b6ad5db74e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -1060,6 +1060,7 @@ object ConvertToLocalRelation extends Rule[LogicalPlan] {
     case Project(projectList, LocalRelation(output, data))
         if !projectList.exists(hasUnevaluableExpr) =>
       val projection = new InterpretedProjection(projectList, output)
+      projection.initialize(0)
       LocalRelation(projectList.map(_.toAttribute), data.map(projection))
   }
 
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
index f0c149c02b..9ceb709185 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
@@ -75,7 +75,7 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks {
 
   protected def evaluate(expression: Expression, inputRow: InternalRow = EmptyRow): Any = {
     expression.foreach {
-      case n: Nondeterministic => n.setInitialValues()
+      case n: Nondeterministic => n.initialize(0)
       case _ =>
     }
     expression.eval(inputRow)
@@ -121,6 +121,7 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks {
     val plan = generateProject(
       GenerateMutableProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil),
       expression)
+    plan.initialize(0)
 
     val actual = plan(inputRow).get(0, expression.dataType)
     if (!checkResult(actual, expected)) {
@@ -182,12 +183,14 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks {
     var plan = generateProject(
       GenerateMutableProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil),
       expression)
+    plan.initialize(0)
     var actual = plan(inputRow).get(0, expression.dataType)
     assert(checkResult(actual, expected))
 
     plan = generateProject(
       GenerateUnsafeProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil),
       expression)
+    plan.initialize(0)
     actual = FromUnsafeProjection(expression.dataType :: Nil)(
       plan(inputRow)).get(0, expression.dataType)
     assert(checkResult(actual, expected))
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenExpressionCachingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenExpressionCachingSuite.scala
index 06dc3bd33b..fe5cb8eda8 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenExpressionCachingSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenExpressionCachingSuite.scala
@@ -31,19 +31,22 @@ class CodegenExpressionCachingSuite extends SparkFunSuite {
     // Use an Add to wrap two of them together in case we only initialize the top level expressions.
     val expr = And(NondeterministicExpression(), NondeterministicExpression())
     val instance = UnsafeProjection.create(Seq(expr))
+    instance.initialize(0)
     assert(instance.apply(null).getBoolean(0) === false)
   }
 
   test("GenerateMutableProjection should initialize expressions") {
     val expr = And(NondeterministicExpression(), NondeterministicExpression())
     val instance = GenerateMutableProjection.generate(Seq(expr))
+    instance.initialize(0)
     assert(instance.apply(null).getBoolean(0) === false)
   }
 
   test("GeneratePredicate should initialize expressions") {
     val expr = And(NondeterministicExpression(), NondeterministicExpression())
     val instance = GeneratePredicate.generate(expr)
-    assert(instance.apply(null) === false)
+    instance.initialize(0)
+    assert(instance.eval(null) === false)
   }
 
   test("GenerateUnsafeProjection should not share expression instances") {
@@ -73,13 +76,13 @@ class CodegenExpressionCachingSuite extends SparkFunSuite {
   test("GeneratePredicate should not share expression instances") {
     val expr1 = MutableExpression()
     val instance1 = GeneratePredicate.generate(expr1)
-    assert(instance1.apply(null) === false)
+    assert(instance1.eval(null) === false)
 
     val expr2 = MutableExpression()
     expr2.mutableState = true
     val instance2 = GeneratePredicate.generate(expr2)
-    assert(instance1.apply(null) === false)
-    assert(instance2.apply(null) === true)
+    assert(instance1.eval(null) === false)
+    assert(instance2.eval(null) === true)
   }
 
 }
@@ -89,7 +92,7 @@ class CodegenExpressionCachingSuite extends SparkFunSuite {
  */
 case class NondeterministicExpression()
   extends LeafExpression with Nondeterministic with CodegenFallback {
-  override protected def initInternal(): Unit = { }
+  override protected def initializeInternal(partitionIndex: Int): Unit = {}
   override protected def evalInternal(input: InternalRow): Any = false
   override def nullable: Boolean = false
   override def dataType: DataType = BooleanType
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
index fdd1fa3648..e485b52b43 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
@@ -71,8 +71,9 @@ case class RowDataSourceScanExec(
     val unsafeRow = if (outputUnsafeRows) {
       rdd
     } else {
-      rdd.mapPartitionsInternal { iter =>
+      rdd.mapPartitionsWithIndexInternal { (index, iter) =>
         val proj = UnsafeProjection.create(schema)
+        proj.initialize(index)
         iter.map(proj)
       }
     }
@@ -284,8 +285,9 @@ case class FileSourceScanExec(
       val unsafeRows = {
         val scan = inputRDD
         if (needsUnsafeRowConversion) {
-          scan.mapPartitionsInternal { iter =>
+          scan.mapPartitionsWithIndexInternal { (index, iter) =>
             val proj = UnsafeProjection.create(schema)
+            proj.initialize(index)
             iter.map(proj)
           }
         } else {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
index 455fb5bfbb..aab087cd98 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
@@ -190,8 +190,9 @@ case class RDDScanExec(
 
   protected override def doExecute(): RDD[InternalRow] = {
     val numOutputRows = longMetric("numOutputRows")
-    rdd.mapPartitionsInternal { iter =>
+    rdd.mapPartitionsWithIndexInternal { (index, iter) =>
       val proj = UnsafeProjection.create(schema)
+      proj.initialize(index)
       iter.map { r =>
         numOutputRows += 1
         proj(r)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala
index 2663129562..19fbf0c162 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala
@@ -94,8 +94,9 @@ case class GenerateExec(
     }
 
     val numOutputRows = longMetric("numOutputRows")
-    rows.mapPartitionsInternal { iter =>
+    rows.mapPartitionsWithIndexInternal { (index, iter) =>
       val proj = UnsafeProjection.create(output, output)
+      proj.initialize(index)
       iter.map { r =>
         numOutputRows += 1
         proj(r)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
index 24d0cffef8..cadab37a44 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
@@ -29,7 +29,7 @@ import org.apache.spark.rdd.{RDD, RDDOperationScope}
 import org.apache.spark.sql.{Row, SparkSession}
 import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.codegen._
+import org.apache.spark.sql.catalyst.expressions.codegen.{Predicate => GenPredicate, _}
 import org.apache.spark.sql.catalyst.plans.QueryPlan
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.execution.metric.SQLMetric
@@ -354,7 +354,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
   }
 
   protected def newPredicate(
-      expression: Expression, inputSchema: Seq[Attribute]): (InternalRow) => Boolean = {
+      expression: Expression, inputSchema: Seq[Attribute]): GenPredicate = {
     GeneratePredicate.generate(expression, inputSchema)
   }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
index 6303483f22..516b9d5444 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
@@ -331,6 +331,7 @@ case class WholeStageCodegenExec(child: SparkPlan) extends UnaryExecNode with Co
           partitionIndex = index;
           this.inputs = inputs;
           ${ctx.initMutableStates()}
+          ${ctx.initPartition()}
         }
 
         ${ctx.declareAddedFunctions()}
@@ -383,10 +384,13 @@ case class WholeStageCodegenExec(child: SparkPlan) extends UnaryExecNode with Co
     } else {
       // Right now, we support up to two input RDDs.
       rdds.head.zipPartitions(rdds(1)) { (leftIter, rightIter) =>
-        val partitionIndex = TaskContext.getPartitionId()
+        Iterator((leftIter, rightIter))
+        // a small hack to obtain the correct partition index
+      }.mapPartitionsWithIndex { (index, zippedIter) =>
+        val (leftIter, rightIter) = zippedIter.next()
         val clazz = CodeGenerator.compile(cleanedSource)
         val buffer = clazz.generate(references).asInstanceOf[BufferedRowIterator]
-        buffer.init(partitionIndex, Array(leftIter, rightIter))
+        buffer.init(index, Array(leftIter, rightIter))
         new Iterator[InternalRow] {
           override def hasNext: Boolean = {
             val v = buffer.hasNext
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
index a5291e0c12..32133f5263 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
@@ -70,9 +70,10 @@ case class ProjectExec(projectList: Seq[NamedExpression], child: SparkPlan)
   }
 
   protected override def doExecute(): RDD[InternalRow] = {
-    child.execute().mapPartitionsInternal { iter =>
+    child.execute().mapPartitionsWithIndexInternal { (index, iter) =>
       val project = UnsafeProjection.create(projectList, child.output,
         subexpressionEliminationEnabled)
+      project.initialize(index)
       iter.map(project)
     }
   }
@@ -205,10 +206,11 @@ case class FilterExec(condition: Expression, child: SparkPlan)
 
   protected override def doExecute(): RDD[InternalRow] = {
     val numOutputRows = longMetric("numOutputRows")
-    child.execute().mapPartitionsInternal { iter =>
+    child.execute().mapPartitionsWithIndexInternal { (index, iter) =>
       val predicate = newPredicate(condition, child.output)
+      predicate.initialize(0)
       iter.filter { row =>
-        val r = predicate(row)
+        val r = predicate.eval(row)
         if (r) numOutputRows += 1
         r
       }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
index b87016d5a5..9028caa446 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
@@ -132,10 +132,11 @@ case class InMemoryTableScanExec(
     val relOutput: AttributeSeq = relation.output
     val buffers = relation.cachedColumnBuffers
 
-    buffers.mapPartitionsInternal { cachedBatchIterator =>
+    buffers.mapPartitionsWithIndexInternal { (index, cachedBatchIterator) =>
       val partitionFilter = newPredicate(
         partitionFilters.reduceOption(And).getOrElse(Literal(true)),
         schema)
+      partitionFilter.initialize(index)
 
       // Find the ordinals and data types of the requested columns.
       val (requestedColumnIndices, requestedColumnDataTypes) =
@@ -147,7 +148,7 @@ case class InMemoryTableScanExec(
       val cachedBatchesToScan =
         if (inMemoryPartitionPruningEnabled) {
           cachedBatchIterator.filter { cachedBatch =>
-            if (!partitionFilter(cachedBatch.stats)) {
+            if (!partitionFilter.eval(cachedBatch.stats)) {
               def statsString: String = schemaIndex.map {
                 case (a, i) =>
                   val value = cachedBatch.stats.get(i, a.dataType)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala
index bfe7e3dea4..f526a19876 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala
@@ -52,7 +52,7 @@ case class BroadcastNestedLoopJoinExec(
       UnspecifiedDistribution :: BroadcastDistribution(IdentityBroadcastMode) :: Nil
   }
 
-  private[this] def genResultProjection: InternalRow => InternalRow = joinType match {
+  private[this] def genResultProjection: UnsafeProjection = joinType match {
     case LeftExistence(j) =>
       UnsafeProjection.create(output, output)
     case other =>
@@ -84,7 +84,7 @@ case class BroadcastNestedLoopJoinExec(
 
   @transient private lazy val boundCondition = {
     if (condition.isDefined) {
-      newPredicate(condition.get, streamed.output ++ broadcast.output)
+      newPredicate(condition.get, streamed.output ++ broadcast.output).eval _
     } else {
       (r: InternalRow) => true
     }
@@ -366,8 +366,9 @@ case class BroadcastNestedLoopJoinExec(
     }
 
     val numOutputRows = longMetric("numOutputRows")
-    resultRdd.mapPartitionsInternal { iter =>
+    resultRdd.mapPartitionsWithIndexInternal { (index, iter) =>
       val resultProj = genResultProjection
+      resultProj.initialize(index)
       iter.map { r =>
         numOutputRows += 1
         resultProj(r)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala
index 15dc9b4066..8341fe2ffd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala
@@ -98,15 +98,15 @@ case class CartesianProductExec(
     val rightResults = right.execute().asInstanceOf[RDD[UnsafeRow]]
 
     val pair = new UnsafeCartesianRDD(leftResults, rightResults, right.output.size)
-    pair.mapPartitionsInternal { iter =>
+    pair.mapPartitionsWithIndexInternal { (index, iter) =>
       val joiner = GenerateUnsafeRowJoiner.create(left.schema, right.schema)
       val filtered = if (condition.isDefined) {
-        val boundCondition: (InternalRow) => Boolean =
-          newPredicate(condition.get, left.output ++ right.output)
+        val boundCondition = newPredicate(condition.get, left.output ++ right.output)
+        boundCondition.initialize(index)
         val joined = new JoinedRow
 
         iter.filter { r =>
-          boundCondition(joined(r._1, r._2))
+          boundCondition.eval(joined(r._1, r._2))
         }
       } else {
         iter
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
index 05c5e2f4cd..1aef5f6864 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
@@ -81,7 +81,7 @@ trait HashJoin {
     UnsafeProjection.create(streamedKeys)
 
   @transient private[this] lazy val boundCondition = if (condition.isDefined) {
-    newPredicate(condition.get, streamedPlan.output ++ buildPlan.output)
+    newPredicate(condition.get, streamedPlan.output ++ buildPlan.output).eval _
   } else {
     (r: InternalRow) => true
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
index ecf7cf289f..ca9c0ed8ce 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
@@ -101,7 +101,7 @@ case class SortMergeJoinExec(
     left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) =>
       val boundCondition: (InternalRow) => Boolean = {
         condition.map { cond =>
-          newPredicate(cond, left.output ++ right.output)
+          newPredicate(cond, left.output ++ right.output).eval _
         }.getOrElse {
           (r: InternalRow) => true
         }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
index 9df56bbf1e..fde3b2a528 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
@@ -87,8 +87,9 @@ case class DeserializeToObjectExec(
   }
 
   override protected def doExecute(): RDD[InternalRow] = {
-    child.execute().mapPartitionsInternal { iter =>
+    child.execute().mapPartitionsWithIndexInternal { (index, iter) =>
       val projection = GenerateSafeProjection.generate(deserializer :: Nil, child.output)
+      projection.initialize(index)
       iter.map(projection)
     }
   }
@@ -124,8 +125,9 @@ case class SerializeFromObjectExec(
   }
 
   override protected def doExecute(): RDD[InternalRow] = {
-    child.execute().mapPartitionsInternal { iter =>
+    child.execute().mapPartitionsWithIndexInternal { (index, iter) =>
       val projection = UnsafeProjection.create(serializer)
+      projection.initialize(index)
       iter.map(projection)
     }
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index 586a0fffeb..0e9a2c6cf7 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -19,7 +19,13 @@ package org.apache.spark.sql
 
 import java.nio.charset.StandardCharsets
 
+import scala.util.Random
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
 import org.apache.spark.sql.functions._
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types._
 
@@ -406,4 +412,50 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
       Seq(Row(true), Row(true))
     )
   }
+
+  private def assertValuesDoNotChangeAfterCoalesceOrUnion(v: Column): Unit = {
+    import DataFrameFunctionsSuite.CodegenFallbackExpr
+    for ((codegenFallback, wholeStage) <- Seq((true, false), (false, false), (false, true))) {
+      val c = if (codegenFallback) {
+        Column(CodegenFallbackExpr(v.expr))
+      } else {
+        v
+      }
+      withSQLConf(
+        (SQLConf.WHOLESTAGE_FALLBACK.key, codegenFallback.toString),
+        (SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, wholeStage.toString)) {
+        val df = spark.range(0, 4, 1, 4).withColumn("c", c)
+        val rows = df.collect()
+        val rowsAfterCoalesce = df.coalesce(2).collect()
+        assert(rows === rowsAfterCoalesce, "Values changed after coalesce when " +
+          s"codegenFallback=$codegenFallback and wholeStage=$wholeStage.")
+
+        val df1 = spark.range(0, 2, 1, 2).withColumn("c", c)
+        val rows1 = df1.collect()
+        val df2 = spark.range(2, 4, 1, 2).withColumn("c", c)
+        val rows2 = df2.collect()
+        val rowsAfterUnion = df1.union(df2).collect()
+        assert(rowsAfterUnion === rows1 ++ rows2, "Values changed after union when " +
+          s"codegenFallback=$codegenFallback and wholeStage=$wholeStage.")
+      }
+    }
+  }
+
+  test("SPARK-14393: values generated by non-deterministic functions shouldn't change after " +
+    "coalesce or union") {
+    Seq(
+      monotonically_increasing_id(), spark_partition_id(),
+      rand(Random.nextLong()), randn(Random.nextLong())
+    ).foreach(assertValuesDoNotChangeAfterCoalesceOrUnion(_))
+  }
+}
+
+object DataFrameFunctionsSuite {
+  case class CodegenFallbackExpr(child: Expression) extends Expression with CodegenFallback {
+    override def children: Seq[Expression] = Seq(child)
+    override def nullable: Boolean = child.nullable
+    override def dataType: DataType = child.dataType
+    override lazy val resolved = true
+    override def eval(input: InternalRow): Any = child.eval(input)
+  }
 }
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala
index 231f204b12..c80695bd3e 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala
@@ -154,8 +154,9 @@ case class HiveTableScanExec(
     val numOutputRows = longMetric("numOutputRows")
     // Avoid to serialize MetastoreRelation because schema is lazy. (see SPARK-15649)
     val outputSchema = schema
-    rdd.mapPartitionsInternal { iter =>
+    rdd.mapPartitionsWithIndexInternal { (index, iter) =>
       val proj = UnsafeProjection.create(outputSchema)
+      proj.initialize(index)
       iter.map { r =>
         numOutputRows += 1
         proj(r)
-- 
GitLab