diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index db535de9e9bb34b446e35ff3107d284e5943d148..e018af35cb18d5d44edc5c25bc93a294271770d6 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -788,14 +788,26 @@ abstract class RDD[T: ClassTag](
   }
 
   /**
-   * [performance] Spark's internal mapPartitions method which skips closure cleaning. It is a
-   * performance API to be used carefully only if we are sure that the RDD elements are
+   * [performance] Spark's internal mapPartitionsWithIndex method that skips closure cleaning.
+   * It is a performance API to be used carefully only if we are sure that the RDD elements are
    * serializable and don't require closure cleaning.
    *
    * @param preservesPartitioning indicates whether the input function preserves the partitioner,
    * which should be `false` unless this is a pair RDD and the input function doesn't modify
    * the keys.
    */
+  private[spark] def mapPartitionsWithIndexInternal[U: ClassTag](
+      f: (Int, Iterator[T]) => Iterator[U],
+      preservesPartitioning: Boolean = false): RDD[U] = withScope {
+    new MapPartitionsRDD(
+      this,
+      (context: TaskContext, index: Int, iter: Iterator[T]) => f(index, iter),
+      preservesPartitioning)
+  }
+
+  /**
+   * [performance] Spark's internal mapPartitions method that skips closure cleaning.
+   */
   private[spark] def mapPartitionsInternal[U: ClassTag](
       f: Iterator[T] => Iterator[U],
       preservesPartitioning: Boolean = false): RDD[U] = withScope {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index 9edc1ceff26a7b23c9b5a5dbb9f40dc2da1bb022..726a231fd814ef0017dd1e9e7b88d4febae7ef7b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -272,17 +272,28 @@ trait Nondeterministic extends Expression {
   final override def deterministic: Boolean = false
   final override def foldable: Boolean = false
 
+  @transient
   private[this] var initialized = false
 
-  final def setInitialValues(): Unit = {
-    initInternal()
+  /**
+   * Initializes internal states given the current partition index and mark this as initialized.
+   * Subclasses should override [[initializeInternal()]].
+   */
+  final def initialize(partitionIndex: Int): Unit = {
+    initializeInternal(partitionIndex)
     initialized = true
   }
 
-  protected def initInternal(): Unit
+  protected def initializeInternal(partitionIndex: Int): Unit
 
+  /**
+   * @inheritdoc
+   * Throws an exception if [[initialize()]] is not called yet.
+   * Subclasses should override [[evalInternal()]].
+   */
   final override def eval(input: InternalRow = null): Any = {
-    require(initialized, "nondeterministic expression should be initialized before evaluate")
+    require(initialized,
+      s"Nondeterministic expression ${this.getClass.getName} should be initialized before eval.")
     evalInternal(input)
   }
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala
index 96929ecf56375d2620164879799c7c4e6b08f031..b6c12c535111922ece06c23dcf6ef698fd0ac7ba 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala
@@ -37,7 +37,7 @@ case class InputFileName() extends LeafExpression with Nondeterministic {
 
   override def prettyName: String = "input_file_name"
 
-  override protected def initInternal(): Unit = {}
+  override protected def initializeInternal(partitionIndex: Int): Unit = {}
 
   override protected def evalInternal(input: InternalRow): UTF8String = {
     InputFileNameHolder.getInputFileName()
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala
index 5b4922e0cf2b7581bbf1b164297836d4498047ae..72b8dcca26e2f64c0cf3f93288818d786ee07f82 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala
@@ -50,9 +50,9 @@ case class MonotonicallyIncreasingID() extends LeafExpression with Nondeterminis
 
   @transient private[this] var partitionMask: Long = _
 
-  override protected def initInternal(): Unit = {
+  override protected def initializeInternal(partitionIndex: Int): Unit = {
     count = 0L
-    partitionMask = TaskContext.getPartitionId().toLong << 33
+    partitionMask = partitionIndex.toLong << 33
   }
 
   override def nullable: Boolean = false
@@ -68,9 +68,10 @@ case class MonotonicallyIncreasingID() extends LeafExpression with Nondeterminis
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     val countTerm = ctx.freshName("count")
     val partitionMaskTerm = ctx.freshName("partitionMask")
-    ctx.addMutableState(ctx.JAVA_LONG, countTerm, s"$countTerm = 0L;")
-    ctx.addMutableState(ctx.JAVA_LONG, partitionMaskTerm,
-      s"$partitionMaskTerm = ((long) org.apache.spark.TaskContext.getPartitionId()) << 33;")
+    ctx.addMutableState(ctx.JAVA_LONG, countTerm, "")
+    ctx.addMutableState(ctx.JAVA_LONG, partitionMaskTerm, "")
+    ctx.addPartitionInitializationStatement(s"$countTerm = 0L;")
+    ctx.addPartitionInitializationStatement(s"$partitionMaskTerm = ((long) partitionIndex) << 33;")
 
     ev.copy(code = s"""
       final ${ctx.javaType(dataType)} ${ev.value} = $partitionMaskTerm + $countTerm;
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
index 03e054d0985112632b4edab1ea1e953e8e140470..476e37e6a9bacc0e9187a6402fb5dd3c8d69c173 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
@@ -23,6 +23,7 @@ import org.apache.spark.sql.types.{DataType, StructType}
 
 /**
  * A [[Projection]] that is calculated by calling the `eval` of each of the specified expressions.
+ *
  * @param expressions a sequence of expressions that determine the value of each column of the
  *                    output row.
  */
@@ -30,10 +31,12 @@ class InterpretedProjection(expressions: Seq[Expression]) extends Projection {
   def this(expressions: Seq[Expression], inputSchema: Seq[Attribute]) =
     this(expressions.map(BindReferences.bindReference(_, inputSchema)))
 
-  expressions.foreach(_.foreach {
-    case n: Nondeterministic => n.setInitialValues()
-    case _ =>
-  })
+  override def initialize(partitionIndex: Int): Unit = {
+    expressions.foreach(_.foreach {
+      case n: Nondeterministic => n.initialize(partitionIndex)
+      case _ =>
+    })
+  }
 
   // null check is required for when Kryo invokes the no-arg constructor.
   protected val exprArray = if (expressions != null) expressions.toArray else null
@@ -54,6 +57,7 @@ class InterpretedProjection(expressions: Seq[Expression]) extends Projection {
 /**
  * A [[MutableProjection]] that is calculated by calling `eval` on each of the specified
  * expressions.
+ *
  * @param expressions a sequence of expressions that determine the value of each column of the
  *                    output row.
  */
@@ -63,10 +67,12 @@ case class InterpretedMutableProjection(expressions: Seq[Expression]) extends Mu
 
   private[this] val buffer = new Array[Any](expressions.size)
 
-  expressions.foreach(_.foreach {
-    case n: Nondeterministic => n.setInitialValues()
-    case _ =>
-  })
+  override def initialize(partitionIndex: Int): Unit = {
+    expressions.foreach(_.foreach {
+      case n: Nondeterministic => n.initialize(partitionIndex)
+      case _ =>
+    })
+  }
 
   private[this] val exprArray = expressions.toArray
   private[this] var mutableRow: InternalRow = new GenericInternalRow(exprArray.length)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala
index 1f675d5b07270f6f8aa0f9edd229772340927617..6bef473cac060be29e20b41404bc1a3760778feb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala
@@ -17,16 +17,15 @@
 
 package org.apache.spark.sql.catalyst.expressions
 
-import org.apache.spark.TaskContext
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
 import org.apache.spark.sql.types.{DataType, IntegerType}
 
 /**
- * Expression that returns the current partition id of the Spark task.
+ * Expression that returns the current partition id.
  */
 @ExpressionDescription(
-  usage = "_FUNC_() - Returns the current partition id of the Spark task",
+  usage = "_FUNC_() - Returns the current partition id",
   extended = "> SELECT _FUNC_();\n 0")
 case class SparkPartitionID() extends LeafExpression with Nondeterministic {
 
@@ -38,16 +37,16 @@ case class SparkPartitionID() extends LeafExpression with Nondeterministic {
 
   override val prettyName = "SPARK_PARTITION_ID"
 
-  override protected def initInternal(): Unit = {
-    partitionId = TaskContext.getPartitionId()
+  override protected def initializeInternal(partitionIndex: Int): Unit = {
+    partitionId = partitionIndex
   }
 
   override protected def evalInternal(input: InternalRow): Int = partitionId
 
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     val idTerm = ctx.freshName("partitionId")
-    ctx.addMutableState(ctx.JAVA_INT, idTerm,
-      s"$idTerm = org.apache.spark.TaskContext.getPartitionId();")
+    ctx.addMutableState(ctx.JAVA_INT, idTerm, "")
+    ctx.addPartitionInitializationStatement(s"$idTerm = partitionIndex;")
     ev.copy(code = s"final ${ctx.javaType(dataType)} ${ev.value} = $idTerm;", isNull = "false")
   }
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 6cab50ae1bf8dbfa496e38eacc62db751a49138c..9c3c6d3b2a7f21a9e5187ef485f9870675c47119 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -184,6 +184,20 @@ class CodegenContext {
     splitExpressions(initCodes, "init", Nil)
   }
 
+  /**
+   * Code statements to initialize states that depend on the partition index.
+   * An integer `partitionIndex` will be made available within the scope.
+   */
+  val partitionInitializationStatements: mutable.ArrayBuffer[String] = mutable.ArrayBuffer.empty
+
+  def addPartitionInitializationStatement(statement: String): Unit = {
+    partitionInitializationStatements += statement
+  }
+
+  def initPartition(): String = {
+    partitionInitializationStatements.mkString("\n")
+  }
+
   /**
    * Holding all the functions those will be added into generated class.
    */
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala
index 6a5a3e7933eea1784e58d4b4c6ff79bc484ff3e8..0322d1dd6a9fffd6d453c9d28d9892c8c150852a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala
@@ -25,15 +25,23 @@ import org.apache.spark.sql.catalyst.expressions.{Expression, LeafExpression, No
 trait CodegenFallback extends Expression {
 
   protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
-    foreach {
-      case n: Nondeterministic => n.setInitialValues()
-      case _ =>
-    }
-
     // LeafNode does not need `input`
     val input = if (this.isInstanceOf[LeafExpression]) "null" else ctx.INPUT_ROW
     val idx = ctx.references.length
     ctx.references += this
+    var childIndex = idx
+    this.foreach {
+      case n: Nondeterministic =>
+        // This might add the current expression twice, but it won't hurt.
+        ctx.references += n
+        childIndex += 1
+        ctx.addPartitionInitializationStatement(
+          s"""
+             |((Nondeterministic) references[$childIndex])
+             |  .initialize(partitionIndex);
+          """.stripMargin)
+      case _ =>
+    }
     val objectTerm = ctx.freshName("obj")
     val placeHolder = ctx.registerComment(this.toString)
     if (nullable) {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
index 5c4b56b0b224cb55d131e8c44b0b79e42b333e70..4d732445544a862400a7dda1407646100a84b211 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
@@ -111,6 +111,10 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP
           ${ctx.initMutableStates()}
         }
 
+        public void initialize(int partitionIndex) {
+          ${ctx.initPartition()}
+        }
+
         ${ctx.declareAddedFunctions()}
 
         public ${classOf[BaseMutableProjection].getName} target(InternalRow row) {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala
index 39aa7b17de6c9f1262eba509fe3a326da5599a79..dcd1ed96a298e2825df1200cccd816af96c5165d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala
@@ -25,19 +25,26 @@ import org.apache.spark.sql.catalyst.expressions._
  */
 abstract class Predicate {
   def eval(r: InternalRow): Boolean
+
+  /**
+   * Initializes internal states given the current partition index.
+   * This is used by nondeterministic expressions to set initial states.
+   * The default implementation does nothing.
+   */
+  def initialize(partitionIndex: Int): Unit = {}
 }
 
 /**
  * Generates bytecode that evaluates a boolean [[Expression]] on a given input [[InternalRow]].
  */
-object GeneratePredicate extends CodeGenerator[Expression, (InternalRow) => Boolean] {
+object GeneratePredicate extends CodeGenerator[Expression, Predicate] {
 
   protected def canonicalize(in: Expression): Expression = ExpressionCanonicalizer.execute(in)
 
   protected def bind(in: Expression, inputSchema: Seq[Attribute]): Expression =
     BindReferences.bindReference(in, inputSchema)
 
-  protected def create(predicate: Expression): ((InternalRow) => Boolean) = {
+  protected def create(predicate: Expression): Predicate = {
     val ctx = newCodeGenContext()
     val eval = predicate.genCode(ctx)
 
@@ -55,6 +62,10 @@ object GeneratePredicate extends CodeGenerator[Expression, (InternalRow) => Bool
           ${ctx.initMutableStates()}
         }
 
+        public void initialize(int partitionIndex) {
+          ${ctx.initPartition()}
+        }
+
         ${ctx.declareAddedFunctions()}
 
         public boolean eval(InternalRow ${ctx.INPUT_ROW}) {
@@ -67,7 +78,6 @@ object GeneratePredicate extends CodeGenerator[Expression, (InternalRow) => Bool
       new CodeAndComment(codeBody, ctx.getPlaceHolderToComments()))
     logDebug(s"Generated predicate '$predicate':\n${CodeFormatter.format(code)}")
 
-    val p = CodeGenerator.compile(code).generate(ctx.references.toArray).asInstanceOf[Predicate]
-    (r: InternalRow) => p.eval(r)
+    CodeGenerator.compile(code).generate(ctx.references.toArray).asInstanceOf[Predicate]
   }
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala
index 2773e1a6662123311e94813093d42f6fd7c9acce..b1cb6edefb852df1ec7e0899d1b7fc34661bba2b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala
@@ -173,6 +173,10 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
           ${ctx.initMutableStates()}
         }
 
+        public void initialize(int partitionIndex) {
+          ${ctx.initPartition()}
+        }
+
         ${ctx.declareAddedFunctions()}
 
         public java.lang.Object apply(java.lang.Object _i) {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
index 7cc45372daa5ad5c8eb84bd99751c713a7b5f6a5..7e4c9089a2cb99a2e5549cb6cd1b5fc0468d62e4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
@@ -380,6 +380,10 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
           ${ctx.initMutableStates()}
         }
 
+        public void initialize(int partitionIndex) {
+          ${ctx.initPartition()}
+        }
+
         ${ctx.declareAddedFunctions()}
 
         // Scala.Function1 need this
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala
index 1510a4796683c350980ea74fe693a926a9c42f15..1b00c9e79da22e496ff07642f3fcae221bfc1ff8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala
@@ -64,7 +64,15 @@ package object expressions  {
    * column of the new row. If the schema of the input row is specified, then the given expression
    * will be bound to that schema.
    */
-  abstract class Projection extends (InternalRow => InternalRow)
+  abstract class Projection extends (InternalRow => InternalRow) {
+
+    /**
+     * Initializes internal states given the current partition index.
+     * This is used by nondeterministic expressions to set initial states.
+     * The default implementation does nothing.
+     */
+    def initialize(partitionIndex: Int): Unit = {}
+  }
 
   /**
    * Converts a [[InternalRow]] to another Row given a sequence of expression that define each
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index 9394e39aadd9d4bdaa50580483463a7d445c9375..c941a576d00d69fe1d1e59337f3ef72af89137e8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -31,10 +31,6 @@ object InterpretedPredicate {
     create(BindReferences.bindReference(expression, inputSchema))
 
   def create(expression: Expression): (InternalRow => Boolean) = {
-    expression.foreach {
-      case n: Nondeterministic => n.setInitialValues()
-      case _ =>
-    }
     (r: InternalRow) => expression.eval(r).asInstanceOf[Boolean]
   }
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala
index ca200768b22863338d32adb36e0ebc56eb1746f2..e09029f5aab9b30d94ac92d3b04b7ccd166ca61b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala
@@ -42,8 +42,8 @@ abstract class RDG extends LeafExpression with Nondeterministic {
    */
   @transient protected var rng: XORShiftRandom = _
 
-  override protected def initInternal(): Unit = {
-    rng = new XORShiftRandom(seed + TaskContext.getPartitionId)
+  override protected def initializeInternal(partitionIndex: Int): Unit = {
+    rng = new XORShiftRandom(seed + partitionIndex)
   }
 
   override def nullable: Boolean = false
@@ -70,8 +70,9 @@ case class Rand(seed: Long) extends RDG {
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     val rngTerm = ctx.freshName("rng")
     val className = classOf[XORShiftRandom].getName
-    ctx.addMutableState(className, rngTerm,
-      s"$rngTerm = new $className(${seed}L + org.apache.spark.TaskContext.getPartitionId());")
+    ctx.addMutableState(className, rngTerm, "")
+    ctx.addPartitionInitializationStatement(
+      s"$rngTerm = new $className(${seed}L + partitionIndex);")
     ev.copy(code = s"""
       final ${ctx.javaType(dataType)} ${ev.value} = $rngTerm.nextDouble();""", isNull = "false")
   }
@@ -93,8 +94,9 @@ case class Randn(seed: Long) extends RDG {
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     val rngTerm = ctx.freshName("rng")
     val className = classOf[XORShiftRandom].getName
-    ctx.addMutableState(className, rngTerm,
-      s"$rngTerm = new $className(${seed}L + org.apache.spark.TaskContext.getPartitionId());")
+    ctx.addMutableState(className, rngTerm, "")
+    ctx.addPartitionInitializationStatement(
+      s"$rngTerm = new $className(${seed}L + partitionIndex);")
     ev.copy(code = s"""
       final ${ctx.javaType(dataType)} ${ev.value} = $rngTerm.nextGaussian();""", isNull = "false")
   }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index e5e2cd7d27d15378bcb3a9a513ce7078032abdb6..b6ad5db74e3c842526577f7baf0a9148a8b613f8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -1060,6 +1060,7 @@ object ConvertToLocalRelation extends Rule[LogicalPlan] {
     case Project(projectList, LocalRelation(output, data))
         if !projectList.exists(hasUnevaluableExpr) =>
       val projection = new InterpretedProjection(projectList, output)
+      projection.initialize(0)
       LocalRelation(projectList.map(_.toAttribute), data.map(projection))
   }
 
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
index f0c149c02b9aa7654ceb5c00240ec9dfe6c54378..9ceb709185417a8e8839f425a9bcd78efefedab7 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
@@ -75,7 +75,7 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks {
 
   protected def evaluate(expression: Expression, inputRow: InternalRow = EmptyRow): Any = {
     expression.foreach {
-      case n: Nondeterministic => n.setInitialValues()
+      case n: Nondeterministic => n.initialize(0)
       case _ =>
     }
     expression.eval(inputRow)
@@ -121,6 +121,7 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks {
     val plan = generateProject(
       GenerateMutableProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil),
       expression)
+    plan.initialize(0)
 
     val actual = plan(inputRow).get(0, expression.dataType)
     if (!checkResult(actual, expected)) {
@@ -182,12 +183,14 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks {
     var plan = generateProject(
       GenerateMutableProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil),
       expression)
+    plan.initialize(0)
     var actual = plan(inputRow).get(0, expression.dataType)
     assert(checkResult(actual, expected))
 
     plan = generateProject(
       GenerateUnsafeProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil),
       expression)
+    plan.initialize(0)
     actual = FromUnsafeProjection(expression.dataType :: Nil)(
       plan(inputRow)).get(0, expression.dataType)
     assert(checkResult(actual, expected))
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenExpressionCachingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenExpressionCachingSuite.scala
index 06dc3bd33b90e3f2a744a09fbf30cd334d50a9b0..fe5cb8eda824f8645a86be742495036b1cea9a1e 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenExpressionCachingSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenExpressionCachingSuite.scala
@@ -31,19 +31,22 @@ class CodegenExpressionCachingSuite extends SparkFunSuite {
     // Use an Add to wrap two of them together in case we only initialize the top level expressions.
     val expr = And(NondeterministicExpression(), NondeterministicExpression())
     val instance = UnsafeProjection.create(Seq(expr))
+    instance.initialize(0)
     assert(instance.apply(null).getBoolean(0) === false)
   }
 
   test("GenerateMutableProjection should initialize expressions") {
     val expr = And(NondeterministicExpression(), NondeterministicExpression())
     val instance = GenerateMutableProjection.generate(Seq(expr))
+    instance.initialize(0)
     assert(instance.apply(null).getBoolean(0) === false)
   }
 
   test("GeneratePredicate should initialize expressions") {
     val expr = And(NondeterministicExpression(), NondeterministicExpression())
     val instance = GeneratePredicate.generate(expr)
-    assert(instance.apply(null) === false)
+    instance.initialize(0)
+    assert(instance.eval(null) === false)
   }
 
   test("GenerateUnsafeProjection should not share expression instances") {
@@ -73,13 +76,13 @@ class CodegenExpressionCachingSuite extends SparkFunSuite {
   test("GeneratePredicate should not share expression instances") {
     val expr1 = MutableExpression()
     val instance1 = GeneratePredicate.generate(expr1)
-    assert(instance1.apply(null) === false)
+    assert(instance1.eval(null) === false)
 
     val expr2 = MutableExpression()
     expr2.mutableState = true
     val instance2 = GeneratePredicate.generate(expr2)
-    assert(instance1.apply(null) === false)
-    assert(instance2.apply(null) === true)
+    assert(instance1.eval(null) === false)
+    assert(instance2.eval(null) === true)
   }
 
 }
@@ -89,7 +92,7 @@ class CodegenExpressionCachingSuite extends SparkFunSuite {
  */
 case class NondeterministicExpression()
   extends LeafExpression with Nondeterministic with CodegenFallback {
-  override protected def initInternal(): Unit = { }
+  override protected def initializeInternal(partitionIndex: Int): Unit = {}
   override protected def evalInternal(input: InternalRow): Any = false
   override def nullable: Boolean = false
   override def dataType: DataType = BooleanType
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
index fdd1fa3648251baaf9f8bc4d041646fc45f7c44e..e485b52b43f76bdcc20f2421bdfa4957a51b19e7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
@@ -71,8 +71,9 @@ case class RowDataSourceScanExec(
     val unsafeRow = if (outputUnsafeRows) {
       rdd
     } else {
-      rdd.mapPartitionsInternal { iter =>
+      rdd.mapPartitionsWithIndexInternal { (index, iter) =>
         val proj = UnsafeProjection.create(schema)
+        proj.initialize(index)
         iter.map(proj)
       }
     }
@@ -284,8 +285,9 @@ case class FileSourceScanExec(
       val unsafeRows = {
         val scan = inputRDD
         if (needsUnsafeRowConversion) {
-          scan.mapPartitionsInternal { iter =>
+          scan.mapPartitionsWithIndexInternal { (index, iter) =>
             val proj = UnsafeProjection.create(schema)
+            proj.initialize(index)
             iter.map(proj)
           }
         } else {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
index 455fb5bfbb6f772e75fccc05239392937467b9ff..aab087cd98716c6d6c5160e6188728eb872e5607 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
@@ -190,8 +190,9 @@ case class RDDScanExec(
 
   protected override def doExecute(): RDD[InternalRow] = {
     val numOutputRows = longMetric("numOutputRows")
-    rdd.mapPartitionsInternal { iter =>
+    rdd.mapPartitionsWithIndexInternal { (index, iter) =>
       val proj = UnsafeProjection.create(schema)
+      proj.initialize(index)
       iter.map { r =>
         numOutputRows += 1
         proj(r)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala
index 266312956266030d89657ed06f729d7309f93a44..19fbf0c162048e0086a5ffd9a1e45fb6e9cfddec 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala
@@ -94,8 +94,9 @@ case class GenerateExec(
     }
 
     val numOutputRows = longMetric("numOutputRows")
-    rows.mapPartitionsInternal { iter =>
+    rows.mapPartitionsWithIndexInternal { (index, iter) =>
       val proj = UnsafeProjection.create(output, output)
+      proj.initialize(index)
       iter.map { r =>
         numOutputRows += 1
         proj(r)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
index 24d0cffef82a21ccbfd3a4fb792d005d2014ff31..cadab37a449aae933c764b8f7f50323f99166839 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
@@ -29,7 +29,7 @@ import org.apache.spark.rdd.{RDD, RDDOperationScope}
 import org.apache.spark.sql.{Row, SparkSession}
 import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.codegen._
+import org.apache.spark.sql.catalyst.expressions.codegen.{Predicate => GenPredicate, _}
 import org.apache.spark.sql.catalyst.plans.QueryPlan
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.execution.metric.SQLMetric
@@ -354,7 +354,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
   }
 
   protected def newPredicate(
-      expression: Expression, inputSchema: Seq[Attribute]): (InternalRow) => Boolean = {
+      expression: Expression, inputSchema: Seq[Attribute]): GenPredicate = {
     GeneratePredicate.generate(expression, inputSchema)
   }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
index 6303483f22fd34a91526b2f1e463888da8168f7d..516b9d5444d31815ceec4f67f68ff8510062d80a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
@@ -331,6 +331,7 @@ case class WholeStageCodegenExec(child: SparkPlan) extends UnaryExecNode with Co
           partitionIndex = index;
           this.inputs = inputs;
           ${ctx.initMutableStates()}
+          ${ctx.initPartition()}
         }
 
         ${ctx.declareAddedFunctions()}
@@ -383,10 +384,13 @@ case class WholeStageCodegenExec(child: SparkPlan) extends UnaryExecNode with Co
     } else {
       // Right now, we support up to two input RDDs.
       rdds.head.zipPartitions(rdds(1)) { (leftIter, rightIter) =>
-        val partitionIndex = TaskContext.getPartitionId()
+        Iterator((leftIter, rightIter))
+        // a small hack to obtain the correct partition index
+      }.mapPartitionsWithIndex { (index, zippedIter) =>
+        val (leftIter, rightIter) = zippedIter.next()
         val clazz = CodeGenerator.compile(cleanedSource)
         val buffer = clazz.generate(references).asInstanceOf[BufferedRowIterator]
-        buffer.init(partitionIndex, Array(leftIter, rightIter))
+        buffer.init(index, Array(leftIter, rightIter))
         new Iterator[InternalRow] {
           override def hasNext: Boolean = {
             val v = buffer.hasNext
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
index a5291e0c12f881cce300ba22ffbcf4a104e8d875..32133f52630cdad7d2750663acf53d4e21ab2dbd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
@@ -70,9 +70,10 @@ case class ProjectExec(projectList: Seq[NamedExpression], child: SparkPlan)
   }
 
   protected override def doExecute(): RDD[InternalRow] = {
-    child.execute().mapPartitionsInternal { iter =>
+    child.execute().mapPartitionsWithIndexInternal { (index, iter) =>
       val project = UnsafeProjection.create(projectList, child.output,
         subexpressionEliminationEnabled)
+      project.initialize(index)
       iter.map(project)
     }
   }
@@ -205,10 +206,11 @@ case class FilterExec(condition: Expression, child: SparkPlan)
 
   protected override def doExecute(): RDD[InternalRow] = {
     val numOutputRows = longMetric("numOutputRows")
-    child.execute().mapPartitionsInternal { iter =>
+    child.execute().mapPartitionsWithIndexInternal { (index, iter) =>
       val predicate = newPredicate(condition, child.output)
+      predicate.initialize(0)
       iter.filter { row =>
-        val r = predicate(row)
+        val r = predicate.eval(row)
         if (r) numOutputRows += 1
         r
       }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
index b87016d5a5696db8b1e4fe98fdca323679fd6c6a..9028caa446e8cb95fd8fbe5ce77e531f86550655 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala
@@ -132,10 +132,11 @@ case class InMemoryTableScanExec(
     val relOutput: AttributeSeq = relation.output
     val buffers = relation.cachedColumnBuffers
 
-    buffers.mapPartitionsInternal { cachedBatchIterator =>
+    buffers.mapPartitionsWithIndexInternal { (index, cachedBatchIterator) =>
       val partitionFilter = newPredicate(
         partitionFilters.reduceOption(And).getOrElse(Literal(true)),
         schema)
+      partitionFilter.initialize(index)
 
       // Find the ordinals and data types of the requested columns.
       val (requestedColumnIndices, requestedColumnDataTypes) =
@@ -147,7 +148,7 @@ case class InMemoryTableScanExec(
       val cachedBatchesToScan =
         if (inMemoryPartitionPruningEnabled) {
           cachedBatchIterator.filter { cachedBatch =>
-            if (!partitionFilter(cachedBatch.stats)) {
+            if (!partitionFilter.eval(cachedBatch.stats)) {
               def statsString: String = schemaIndex.map {
                 case (a, i) =>
                   val value = cachedBatch.stats.get(i, a.dataType)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala
index bfe7e3dea45df26ff44007257d36bea9a603a3de..f526a19876670f0b50b00f0b9a696e43c27237eb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala
@@ -52,7 +52,7 @@ case class BroadcastNestedLoopJoinExec(
       UnspecifiedDistribution :: BroadcastDistribution(IdentityBroadcastMode) :: Nil
   }
 
-  private[this] def genResultProjection: InternalRow => InternalRow = joinType match {
+  private[this] def genResultProjection: UnsafeProjection = joinType match {
     case LeftExistence(j) =>
       UnsafeProjection.create(output, output)
     case other =>
@@ -84,7 +84,7 @@ case class BroadcastNestedLoopJoinExec(
 
   @transient private lazy val boundCondition = {
     if (condition.isDefined) {
-      newPredicate(condition.get, streamed.output ++ broadcast.output)
+      newPredicate(condition.get, streamed.output ++ broadcast.output).eval _
     } else {
       (r: InternalRow) => true
     }
@@ -366,8 +366,9 @@ case class BroadcastNestedLoopJoinExec(
     }
 
     val numOutputRows = longMetric("numOutputRows")
-    resultRdd.mapPartitionsInternal { iter =>
+    resultRdd.mapPartitionsWithIndexInternal { (index, iter) =>
       val resultProj = genResultProjection
+      resultProj.initialize(index)
       iter.map { r =>
         numOutputRows += 1
         resultProj(r)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala
index 15dc9b40662e2b32ddb59e64270d14d425db7bad..8341fe2ffd078234849e38166ed3f7ade28435db 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala
@@ -98,15 +98,15 @@ case class CartesianProductExec(
     val rightResults = right.execute().asInstanceOf[RDD[UnsafeRow]]
 
     val pair = new UnsafeCartesianRDD(leftResults, rightResults, right.output.size)
-    pair.mapPartitionsInternal { iter =>
+    pair.mapPartitionsWithIndexInternal { (index, iter) =>
       val joiner = GenerateUnsafeRowJoiner.create(left.schema, right.schema)
       val filtered = if (condition.isDefined) {
-        val boundCondition: (InternalRow) => Boolean =
-          newPredicate(condition.get, left.output ++ right.output)
+        val boundCondition = newPredicate(condition.get, left.output ++ right.output)
+        boundCondition.initialize(index)
         val joined = new JoinedRow
 
         iter.filter { r =>
-          boundCondition(joined(r._1, r._2))
+          boundCondition.eval(joined(r._1, r._2))
         }
       } else {
         iter
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
index 05c5e2f4cd77bda29fd5b033dd4eeee137a7aa85..1aef5f6864263e1c894f0970830127e10b9b882d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
@@ -81,7 +81,7 @@ trait HashJoin {
     UnsafeProjection.create(streamedKeys)
 
   @transient private[this] lazy val boundCondition = if (condition.isDefined) {
-    newPredicate(condition.get, streamedPlan.output ++ buildPlan.output)
+    newPredicate(condition.get, streamedPlan.output ++ buildPlan.output).eval _
   } else {
     (r: InternalRow) => true
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
index ecf7cf289f034b917983b7ce4796bf3a58a17aeb..ca9c0ed8cec32e01e836a765542876164121cf30 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
@@ -101,7 +101,7 @@ case class SortMergeJoinExec(
     left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) =>
       val boundCondition: (InternalRow) => Boolean = {
         condition.map { cond =>
-          newPredicate(cond, left.output ++ right.output)
+          newPredicate(cond, left.output ++ right.output).eval _
         }.getOrElse {
           (r: InternalRow) => true
         }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
index 9df56bbf1ef87f1aca57da6be8a7dd09026983bd..fde3b2a528994a3c2f7a97e3dddbb9b071d0d17c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
@@ -87,8 +87,9 @@ case class DeserializeToObjectExec(
   }
 
   override protected def doExecute(): RDD[InternalRow] = {
-    child.execute().mapPartitionsInternal { iter =>
+    child.execute().mapPartitionsWithIndexInternal { (index, iter) =>
       val projection = GenerateSafeProjection.generate(deserializer :: Nil, child.output)
+      projection.initialize(index)
       iter.map(projection)
     }
   }
@@ -124,8 +125,9 @@ case class SerializeFromObjectExec(
   }
 
   override protected def doExecute(): RDD[InternalRow] = {
-    child.execute().mapPartitionsInternal { iter =>
+    child.execute().mapPartitionsWithIndexInternal { (index, iter) =>
       val projection = UnsafeProjection.create(serializer)
+      projection.initialize(index)
       iter.map(projection)
     }
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index 586a0fffeb7a1f5448209dc068e4d1ba2109d1d5..0e9a2c6cf7dec6e44d82f10f67da80413da8711e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -19,7 +19,13 @@ package org.apache.spark.sql
 
 import java.nio.charset.StandardCharsets
 
+import scala.util.Random
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
 import org.apache.spark.sql.functions._
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types._
 
@@ -406,4 +412,50 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
       Seq(Row(true), Row(true))
     )
   }
+
+  private def assertValuesDoNotChangeAfterCoalesceOrUnion(v: Column): Unit = {
+    import DataFrameFunctionsSuite.CodegenFallbackExpr
+    for ((codegenFallback, wholeStage) <- Seq((true, false), (false, false), (false, true))) {
+      val c = if (codegenFallback) {
+        Column(CodegenFallbackExpr(v.expr))
+      } else {
+        v
+      }
+      withSQLConf(
+        (SQLConf.WHOLESTAGE_FALLBACK.key, codegenFallback.toString),
+        (SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, wholeStage.toString)) {
+        val df = spark.range(0, 4, 1, 4).withColumn("c", c)
+        val rows = df.collect()
+        val rowsAfterCoalesce = df.coalesce(2).collect()
+        assert(rows === rowsAfterCoalesce, "Values changed after coalesce when " +
+          s"codegenFallback=$codegenFallback and wholeStage=$wholeStage.")
+
+        val df1 = spark.range(0, 2, 1, 2).withColumn("c", c)
+        val rows1 = df1.collect()
+        val df2 = spark.range(2, 4, 1, 2).withColumn("c", c)
+        val rows2 = df2.collect()
+        val rowsAfterUnion = df1.union(df2).collect()
+        assert(rowsAfterUnion === rows1 ++ rows2, "Values changed after union when " +
+          s"codegenFallback=$codegenFallback and wholeStage=$wholeStage.")
+      }
+    }
+  }
+
+  test("SPARK-14393: values generated by non-deterministic functions shouldn't change after " +
+    "coalesce or union") {
+    Seq(
+      monotonically_increasing_id(), spark_partition_id(),
+      rand(Random.nextLong()), randn(Random.nextLong())
+    ).foreach(assertValuesDoNotChangeAfterCoalesceOrUnion(_))
+  }
+}
+
+object DataFrameFunctionsSuite {
+  case class CodegenFallbackExpr(child: Expression) extends Expression with CodegenFallback {
+    override def children: Seq[Expression] = Seq(child)
+    override def nullable: Boolean = child.nullable
+    override def dataType: DataType = child.dataType
+    override lazy val resolved = true
+    override def eval(input: InternalRow): Any = child.eval(input)
+  }
 }
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala
index 231f204b12b47939e074e814fe830756ee4d6bfe..c80695bd3e0feb95d54232e7fdd9feaa8fe2fb70 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala
@@ -154,8 +154,9 @@ case class HiveTableScanExec(
     val numOutputRows = longMetric("numOutputRows")
     // Avoid to serialize MetastoreRelation because schema is lazy. (see SPARK-15649)
     val outputSchema = schema
-    rdd.mapPartitionsInternal { iter =>
+    rdd.mapPartitionsWithIndexInternal { (index, iter) =>
       val proj = UnsafeProjection.create(outputSchema)
+      proj.initialize(index)
       iter.map { r =>
         numOutputRows += 1
         proj(r)