diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
index 4727ff1885ad786c6b180c7d10964765f13cbdd7..72fe06545910c6b8ac562eb51a44a70bcf078028 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
@@ -62,9 +62,10 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean)
     val javaType = ctx.javaType(dataType)
     val value = ctx.getValue(ctx.INPUT_ROW, dataType, ordinal.toString)
     if (ctx.currentVars != null && ctx.currentVars(ordinal) != null) {
-      ev.isNull = ctx.currentVars(ordinal).isNull
-      ev.value = ctx.currentVars(ordinal).value
-      ""
+      val oev = ctx.currentVars(ordinal)
+      ev.isNull = oev.isNull
+      ev.value = oev.value
+      oev.code
     } else if (nullable) {
       s"""
         boolean ${ev.isNull} = ${ctx.INPUT_ROW}.isNullAt($ordinal);
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 63e19564dd86178a549a56d5358b6b0d53d29821..c4265a753933fec8faa5ad187a6492b580610c99 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -37,6 +37,8 @@ import org.apache.spark.util.Utils
  * Java source for evaluating an [[Expression]] given a [[InternalRow]] of input.
  *
  * @param code The sequence of statements required to evaluate the expression.
+ *             It should be empty string, if `isNull` and `value` are already existed, or no code
+ *             needed to evaluate them (literals).
  * @param isNull A term that holds a boolean value representing whether the expression evaluated
  *                 to null.
  * @param value A term for a (possibly primitive) value of the result of the evaluation. Not
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
index 4ad07508ca429dbcf2fbae7b47a817d6bded349d..3662ed74d2eae0e7209492fdf18cbf034c73a2b6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
@@ -151,9 +151,6 @@ private[sql] case class PhysicalRDD(
     val exprs = output.zipWithIndex.map(x => new BoundReference(x._2, x._1.dataType, true))
     val row = ctx.freshName("row")
     val numOutputRows = metricTerm(ctx, "numOutputRows")
-    ctx.INPUT_ROW = row
-    ctx.currentVars = null
-    val columns = exprs.map(_.gen(ctx))
 
     // The input RDD can either return (all) ColumnarBatches or InternalRows. We determine this
     // by looking at the first value of the RDD and then calling the function which will process
@@ -161,7 +158,9 @@ private[sql] case class PhysicalRDD(
     // TODO: The abstractions between this class and SqlNewHadoopRDD makes it difficult to know
     // here which path to use. Fix this.
 
-
+    ctx.INPUT_ROW = row
+    ctx.currentVars = null
+    val columns1 = exprs.map(_.gen(ctx))
     val scanBatches = ctx.freshName("processBatches")
     ctx.addNewFunction(scanBatches,
       s"""
@@ -170,12 +169,11 @@ private[sql] case class PhysicalRDD(
       |     int numRows = $batch.numRows();
       |     if ($idx == 0) $numOutputRows.add(numRows);
       |
-      |     while ($idx < numRows) {
+      |     while (!shouldStop() && $idx < numRows) {
       |       InternalRow $row = $batch.getRow($idx++);
-      |       ${columns.map(_.code).mkString("\n").trim}
-      |       ${consume(ctx, columns).trim}
-      |       if (shouldStop()) return;
+      |       ${consume(ctx, columns1).trim}
       |     }
+      |     if (shouldStop()) return;
       |
       |     if (!$input.hasNext()) {
       |       $batch = null;
@@ -186,30 +184,37 @@ private[sql] case class PhysicalRDD(
       |   }
       | }""".stripMargin)
 
+    ctx.INPUT_ROW = row
+    ctx.currentVars = null
+    val columns2 = exprs.map(_.gen(ctx))
+    val inputRow = if (isUnsafeRow) row else null
     val scanRows = ctx.freshName("processRows")
     ctx.addNewFunction(scanRows,
       s"""
        | private void $scanRows(InternalRow $row) throws java.io.IOException {
-       |   while (true) {
+       |   boolean firstRow = true;
+       |   while (!shouldStop() && (firstRow || $input.hasNext())) {
+       |     if (firstRow) {
+       |       firstRow = false;
+       |     } else {
+       |       $row = (InternalRow) $input.next();
+       |     }
        |     $numOutputRows.add(1);
-       |     ${columns.map(_.code).mkString("\n").trim}
-       |     ${consume(ctx, columns).trim}
-       |     if (shouldStop()) return;
-       |     if (!$input.hasNext()) break;
-       |     $row = (InternalRow)$input.next();
+       |     ${consume(ctx, columns2, inputRow).trim}
        |   }
        | }""".stripMargin)
 
+    val value = ctx.freshName("value")
     s"""
        | if ($batch != null) {
        |   $scanBatches();
        | } else if ($input.hasNext()) {
-       |   Object value = $input.next();
-       |   if (value instanceof $columnarBatchClz) {
-       |     $batch = ($columnarBatchClz)value;
+       |   Object $value = $input.next();
+       |   if ($value instanceof $columnarBatchClz) {
+       |     $batch = ($columnarBatchClz)$value;
        |     $scanBatches();
        |   } else {
-       |     $scanRows((InternalRow)value);
+       |     $scanRows((InternalRow) $value);
        |   }
        | }
      """.stripMargin
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala
index 12998a38f59e7018f6ea7c00b5aac5646e55f225..524285bc87123a086ff0bfbc3c56ec7d04721325 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala
@@ -185,8 +185,10 @@ case class Expand(
 
     val numOutput = metricTerm(ctx, "numOutputRows")
     val i = ctx.freshName("i")
+    // these column have to declared before the loop.
+    val evaluate = evaluateVariables(outputColumns)
     s"""
-       |${outputColumns.map(_.code).mkString("\n").trim}
+       |$evaluate
        |for (int $i = 0; $i < ${projections.length}; $i ++) {
        |  switch ($i) {
        |    ${cases.mkString("\n").trim}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
index 6d231bf74a0e9b21ec54cbbc7cf1f0c938201b1e..45578d50bfc0d1690cd88c4d249e4068038f5cbe 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
@@ -81,11 +81,14 @@ trait CodegenSupport extends SparkPlan {
     this.parent = parent
     ctx.freshNamePrefix = variablePrefix
     waitForSubqueries()
-    doProduce(ctx)
+    s"""
+       |/*** PRODUCE: ${toCommentSafeString(this.simpleString)} */
+       |${doProduce(ctx)}
+     """.stripMargin
   }
 
   /**
-    * Generate the Java source code to process, should be overrided by subclass to support codegen.
+    * Generate the Java source code to process, should be overridden by subclass to support codegen.
     *
     * doProduce() usually generate the framework, for example, aggregation could generate this:
     *
@@ -94,11 +97,11 @@ trait CodegenSupport extends SparkPlan {
     *     # call child.produce()
     *     initialized = true;
     *   }
-    *   while (hashmap.hasNext()) {
+    *   while (!shouldStop() && hashmap.hasNext()) {
     *     row = hashmap.next();
     *     # build the aggregation results
-    *     # create varialbles for results
-    *     # call consume(), wich will call parent.doConsume()
+    *     # create variables for results
+    *     # call consume(), which will call parent.doConsume()
     *   }
     */
   protected def doProduce(ctx: CodegenContext): String
@@ -114,27 +117,71 @@ trait CodegenSupport extends SparkPlan {
   }
 
   /**
-    * Consume the columns generated from it's child, call doConsume() or emit the rows.
+    * Returns source code to evaluate all the variables, and clear the code of them, to prevent
+    * them to be evaluated twice.
+    */
+  protected def evaluateVariables(variables: Seq[ExprCode]): String = {
+    val evaluate = variables.filter(_.code != "").map(_.code.trim).mkString("\n")
+    variables.foreach(_.code = "")
+    evaluate
+  }
+
+  /**
+    * Returns source code to evaluate the variables for required attributes, and clear the code
+    * of evaluated variables, to prevent them to be evaluated twice..
     */
+  protected def evaluateRequiredVariables(
+      attributes: Seq[Attribute],
+      variables: Seq[ExprCode],
+      required: AttributeSet): String = {
+    var evaluateVars = ""
+    variables.zipWithIndex.foreach { case (ev, i) =>
+      if (ev.code != "" && required.contains(attributes(i))) {
+        evaluateVars += ev.code.trim + "\n"
+        ev.code = ""
+      }
+    }
+    evaluateVars
+  }
+
+  /**
+   * The subset of inputSet those should be evaluated before this plan.
+   *
+   * We will use this to insert some code to access those columns that are actually used by current
+   * plan before calling doConsume().
+   */
+  def usedInputs: AttributeSet = references
+
+  /**
+   * Consume the columns generated from its child, call doConsume() or emit the rows.
+   *
+   * An operator could generate variables for the output, or a row, either one could be null.
+   *
+   * If the row is not null, we create variables to access the columns that are actually used by
+   * current plan before calling doConsume().
+   */
   def consumeChild(
       ctx: CodegenContext,
       child: SparkPlan,
       input: Seq[ExprCode],
       row: String = null): String = {
     ctx.freshNamePrefix = variablePrefix
-    if (row != null) {
-      ctx.currentVars = null
-      ctx.INPUT_ROW = row
-      val evals = child.output.zipWithIndex.map { case (attr, i) =>
-        BoundReference(i, attr.dataType, attr.nullable).gen(ctx)
+    val inputVars =
+      if (row != null) {
+        ctx.currentVars = null
+        ctx.INPUT_ROW = row
+        child.output.zipWithIndex.map { case (attr, i) =>
+          BoundReference(i, attr.dataType, attr.nullable).gen(ctx)
+        }
+      } else {
+        input
       }
-      s"""
-         | ${evals.map(_.code).mkString("\n")}
-         | ${doConsume(ctx, evals)}
-       """.stripMargin
-    } else {
-      doConsume(ctx, input)
-    }
+    s"""
+       |
+       |/*** CONSUME: ${toCommentSafeString(this.simpleString)} */
+       |${evaluateRequiredVariables(child.output, inputVars, usedInputs)}
+       |${doConsume(ctx, inputVars)}
+     """.stripMargin
   }
 
   /**
@@ -145,9 +192,8 @@ trait CodegenSupport extends SparkPlan {
     * For example, Filter will generate the code like this:
     *
     *   # code to evaluate the predicate expression, result is isNull1 and value2
-    *   if (isNull1 || value2) {
-    *     # call consume(), which will call parent.doConsume()
-    *   }
+    *   if (isNull1 || !value2) continue;
+    *   # call consume(), which will call parent.doConsume()
     */
   protected def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = {
     throw new UnsupportedOperationException
@@ -190,13 +236,9 @@ case class InputAdapter(child: SparkPlan) extends UnaryNode with CodegenSupport
     ctx.currentVars = null
     val columns = exprs.map(_.gen(ctx))
     s"""
-       | while ($input.hasNext()) {
+       | while (!shouldStop() && $input.hasNext()) {
        |   InternalRow $row = (InternalRow) $input.next();
-       |   ${columns.map(_.code).mkString("\n").trim}
        |   ${consume(ctx, columns).trim}
-       |   if (shouldStop()) {
-       |     return;
-       |   }
        | }
      """.stripMargin
   }
@@ -332,10 +374,12 @@ case class WholeStageCodegen(child: SparkPlan) extends UnaryNode with CodegenSup
         val colExprs = output.zipWithIndex.map { case (attr, i) =>
           BoundReference(i, attr.dataType, attr.nullable)
         }
+        val evaluateInputs = evaluateVariables(input)
         // generate the code to create a UnsafeRow
         ctx.currentVars = input
         val code = GenerateUnsafeProjection.createCode(ctx, colExprs, false)
         s"""
+           |$evaluateInputs
            |${code.code.trim}
            |append(${code.value}.copy());
          """.stripMargin.trim
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
index a46722963a6e17c853154ffa0966efaa53cbc3d1..f07add83d5849b0069f107ba9b39b3a93b4e0350 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
@@ -116,6 +116,8 @@ case class TungstenAggregate(
   // all the mode of aggregate expressions
   private val modes = aggregateExpressions.map(_.mode).distinct
 
+  override def usedInputs: AttributeSet = inputSet
+
   override def supportCodegen: Boolean = {
     // ImperativeAggregate is not supported right now
     !aggregateExpressions.exists(_.aggregateFunction.isInstanceOf[ImperativeAggregate])
@@ -164,23 +166,24 @@ case class TungstenAggregate(
        """.stripMargin
       ExprCode(ev.code + initVars, isNull, value)
     }
+    val initBufVar = evaluateVariables(bufVars)
 
     // generate variables for output
-    val bufferAttrs = functions.flatMap(_.aggBufferAttributes)
     val (resultVars, genResult) = if (modes.contains(Final) || modes.contains(Complete)) {
       // evaluate aggregate results
       ctx.currentVars = bufVars
       val aggResults = functions.map(_.evaluateExpression).map { e =>
-        BindReferences.bindReference(e, bufferAttrs).gen(ctx)
+        BindReferences.bindReference(e, aggregateBufferAttributes).gen(ctx)
       }
+      val evaluateAggResults = evaluateVariables(aggResults)
       // evaluate result expressions
       ctx.currentVars = aggResults
       val resultVars = resultExpressions.map { e =>
         BindReferences.bindReference(e, aggregateAttributes).gen(ctx)
       }
       (resultVars, s"""
-        | ${aggResults.map(_.code).mkString("\n")}
-        | ${resultVars.map(_.code).mkString("\n")}
+        |$evaluateAggResults
+        |${evaluateVariables(resultVars)}
        """.stripMargin)
     } else if (modes.contains(Partial) || modes.contains(PartialMerge)) {
       // output the aggregate buffer directly
@@ -188,7 +191,7 @@ case class TungstenAggregate(
     } else {
       // no aggregate function, the result should be literals
       val resultVars = resultExpressions.map(_.gen(ctx))
-      (resultVars, resultVars.map(_.code).mkString("\n"))
+      (resultVars, evaluateVariables(resultVars))
     }
 
     val doAgg = ctx.freshName("doAggregateWithoutKey")
@@ -196,7 +199,7 @@ case class TungstenAggregate(
       s"""
          | private void $doAgg() throws java.io.IOException {
          |   // initialize aggregation buffer
-         |   ${bufVars.map(_.code).mkString("\n")}
+         |   $initBufVar
          |
          |   ${child.asInstanceOf[CodegenSupport].produce(ctx, this)}
          | }
@@ -204,7 +207,7 @@ case class TungstenAggregate(
 
     val numOutput = metricTerm(ctx, "numOutputRows")
     s"""
-       | if (!$initAgg) {
+       | while (!$initAgg) {
        |   $initAgg = true;
        |   $doAgg();
        |
@@ -241,7 +244,7 @@ case class TungstenAggregate(
     }
     s"""
        | // do aggregate
-       | ${aggVals.map(_.code).mkString("\n").trim}
+       | ${evaluateVariables(aggVals)}
        | // update aggregation buffer
        | ${updates.mkString("\n").trim}
      """.stripMargin
@@ -252,8 +255,7 @@ case class TungstenAggregate(
   private val declFunctions = aggregateExpressions.map(_.aggregateFunction)
     .filter(_.isInstanceOf[DeclarativeAggregate])
     .map(_.asInstanceOf[DeclarativeAggregate])
-  private val bufferAttributes = declFunctions.flatMap(_.aggBufferAttributes)
-  private val bufferSchema = StructType.fromAttributes(bufferAttributes)
+  private val bufferSchema = StructType.fromAttributes(aggregateBufferAttributes)
 
   // The name for HashMap
   private var hashMapTerm: String = _
@@ -318,7 +320,7 @@ case class TungstenAggregate(
       val mergeExpr = declFunctions.flatMap(_.mergeExpressions)
       val mergeProjection = newMutableProjection(
         mergeExpr,
-        bufferAttributes ++ declFunctions.flatMap(_.inputAggBufferAttributes),
+        aggregateBufferAttributes ++ declFunctions.flatMap(_.inputAggBufferAttributes),
         subexpressionEliminationEnabled)()
       val joinedRow = new JoinedRow()
 
@@ -380,15 +382,18 @@ case class TungstenAggregate(
       val keyVars = groupingExpressions.zipWithIndex.map { case (e, i) =>
         BoundReference(i, e.dataType, e.nullable).gen(ctx)
       }
+      val evaluateKeyVars = evaluateVariables(keyVars)
       ctx.INPUT_ROW = bufferTerm
-      val bufferVars = bufferAttributes.zipWithIndex.map { case (e, i) =>
+      val bufferVars = aggregateBufferAttributes.zipWithIndex.map { case (e, i) =>
         BoundReference(i, e.dataType, e.nullable).gen(ctx)
       }
+      val evaluateBufferVars = evaluateVariables(bufferVars)
       // evaluate the aggregation result
       ctx.currentVars = bufferVars
       val aggResults = declFunctions.map(_.evaluateExpression).map { e =>
-        BindReferences.bindReference(e, bufferAttributes).gen(ctx)
+        BindReferences.bindReference(e, aggregateBufferAttributes).gen(ctx)
       }
+      val evaluateAggResults = evaluateVariables(aggResults)
       // generate the final result
       ctx.currentVars = keyVars ++ aggResults
       val inputAttrs = groupingAttributes ++ aggregateAttributes
@@ -396,11 +401,9 @@ case class TungstenAggregate(
         BindReferences.bindReference(e, inputAttrs).gen(ctx)
       }
       s"""
-       ${keyVars.map(_.code).mkString("\n")}
-       ${bufferVars.map(_.code).mkString("\n")}
-       ${aggResults.map(_.code).mkString("\n")}
-       ${resultVars.map(_.code).mkString("\n")}
-
+       $evaluateKeyVars
+       $evaluateBufferVars
+       $evaluateAggResults
        ${consume(ctx, resultVars)}
        """
 
@@ -422,10 +425,7 @@ case class TungstenAggregate(
       val eval = resultExpressions.map{ e =>
         BindReferences.bindReference(e, groupingAttributes).gen(ctx)
       }
-      s"""
-       ${eval.map(_.code).mkString("\n")}
-       ${consume(ctx, eval)}
-       """
+      consume(ctx, eval)
     }
   }
 
@@ -508,8 +508,8 @@ case class TungstenAggregate(
     ctx.currentVars = input
     val hashEval = BindReferences.bindReference(hashExpr, child.output).gen(ctx)
 
-    val inputAttr = bufferAttributes ++ child.output
-    ctx.currentVars = new Array[ExprCode](bufferAttributes.length) ++ input
+    val inputAttr = aggregateBufferAttributes ++ child.output
+    ctx.currentVars = new Array[ExprCode](aggregateBufferAttributes.length) ++ input
     ctx.INPUT_ROW = buffer
     // TODO: support subexpression elimination
     val evals = updateExpr.map(BindReferences.bindReference(_, inputAttr).gen(ctx))
@@ -557,7 +557,7 @@ case class TungstenAggregate(
      $incCounter
 
      // evaluate aggregate function
-     ${evals.map(_.code).mkString("\n").trim}
+     ${evaluateVariables(evals)}
      // update aggregate buffer
      ${updates.mkString("\n").trim}
      """
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index b2f443c0e9ae67d87b6bdb2cefee51074aad79c3..4a9e736f7abdb4d73125402f7b44ab21ea081e1d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -39,15 +39,26 @@ case class Project(projectList: Seq[NamedExpression], child: SparkPlan)
     child.asInstanceOf[CodegenSupport].produce(ctx, this)
   }
 
+  override def usedInputs: AttributeSet = {
+    // only the attributes those are used at least twice should be evaluated before this plan,
+    // otherwise we could defer the evaluation until output attribute is actually used.
+    val usedExprIds = projectList.flatMap(_.collect {
+      case a: Attribute => a.exprId
+    })
+    val usedMoreThanOnce = usedExprIds.groupBy(id => id).filter(_._2.size > 1).keySet
+    references.filter(a => usedMoreThanOnce.contains(a.exprId))
+  }
+
   override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = {
     val exprs = projectList.map(x =>
       ExpressionCanonicalizer.execute(BindReferences.bindReference(x, child.output)))
     ctx.currentVars = input
-    val output = exprs.map(_.gen(ctx))
+    val resultVars = exprs.map(_.gen(ctx))
+    // Evaluation of non-deterministic expressions can't be deferred.
+    val nonDeterministicAttrs = projectList.filterNot(_.deterministic).map(_.toAttribute)
     s"""
-       | ${output.map(_.code).mkString("\n")}
-       |
-       | ${consume(ctx, output)}
+       |${evaluateRequiredVariables(output, resultVars, AttributeSet(nonDeterministicAttrs))}
+       |${consume(ctx, resultVars)}
      """.stripMargin
   }
 
@@ -89,11 +100,10 @@ case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode wit
       s""
     }
     s"""
-       | ${eval.code}
-       | if ($nullCheck ${eval.value}) {
-       |   $numOutput.add(1);
-       |   ${consume(ctx, ctx.currentVars)}
-       | }
+       |${eval.code}
+       |if (!($nullCheck ${eval.value})) continue;
+       |$numOutput.add(1);
+       |${consume(ctx, ctx.currentVars)}
      """.stripMargin
   }
 
@@ -228,15 +238,13 @@ case class Range(
       |   }
       | }
       |
-      | while (!$overflow && $checkEnd) {
+      | while (!$overflow && $checkEnd && !shouldStop()) {
       |  long $value = $number;
       |  $number += ${step}L;
       |  if ($number < $value ^ ${step}L < 0) {
       |    $overflow = true;
       |  }
       |  ${consume(ctx, Seq(ev))}
-      |
-      |  if (shouldStop()) return;
       | }
      """.stripMargin
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
index 6699dbafe7e74ced41d62621d0296e56d6824967..c52662a61e7f82862ef12a44471b8e829e5ddcdb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
@@ -190,40 +190,38 @@ case class BroadcastHashJoin(
     val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input)
     val matched = ctx.freshName("matched")
     val buildVars = genBuildSideVars(ctx, matched)
-    val resultVars = buildSide match {
-      case BuildLeft => buildVars ++ input
-      case BuildRight => input ++ buildVars
-    }
     val numOutput = metricTerm(ctx, "numOutputRows")
 
-    val outputCode = if (condition.isDefined) {
+    val checkCondition = if (condition.isDefined) {
+      val expr = condition.get
+      // evaluate the variables from build side that used by condition
+      val eval = evaluateRequiredVariables(buildPlan.output, buildVars, expr.references)
       // filter the output via condition
-      ctx.currentVars = resultVars
-      val ev = BindReferences.bindReference(condition.get, this.output).gen(ctx)
+      ctx.currentVars = input ++ buildVars
+      val ev = BindReferences.bindReference(expr, streamedPlan.output ++ buildPlan.output).gen(ctx)
       s"""
+         |$eval
          |${ev.code}
-         |if (!${ev.isNull} && ${ev.value}) {
-         |  $numOutput.add(1);
-         |  ${consume(ctx, resultVars)}
-         |}
+         |if (${ev.isNull} || !${ev.value}) continue;
        """.stripMargin
     } else {
-      s"""
-         |$numOutput.add(1);
-         |${consume(ctx, resultVars)}
-       """.stripMargin
+      ""
     }
 
+    val resultVars = buildSide match {
+      case BuildLeft => buildVars ++ input
+      case BuildRight => input ++ buildVars
+    }
     if (broadcastRelation.value.isInstanceOf[UniqueHashedRelation]) {
       s"""
          |// generate join key for stream side
          |${keyEv.code}
          |// find matches from HashedRelation
          |UnsafeRow $matched = $anyNull ? null: (UnsafeRow)$relationTerm.getValue(${keyEv.value});
-         |if ($matched != null) {
-         |  ${buildVars.map(_.code).mkString("\n")}
-         |  $outputCode
-         |}
+         |if ($matched == null) continue;
+         |$checkCondition
+         |$numOutput.add(1);
+         |${consume(ctx, resultVars)}
        """.stripMargin
 
     } else {
@@ -236,13 +234,13 @@ case class BroadcastHashJoin(
          |${keyEv.code}
          |// find matches from HashRelation
          |$bufferType $matches = $anyNull ? null : ($bufferType)$relationTerm.get(${keyEv.value});
-         |if ($matches != null) {
-         |  int $size = $matches.size();
-         |  for (int $i = 0; $i < $size; $i++) {
-         |    UnsafeRow $matched = (UnsafeRow) $matches.apply($i);
-         |    ${buildVars.map(_.code).mkString("\n")}
-         |    $outputCode
-         |  }
+         |if ($matches == null) continue;
+         |int $size = $matches.size();
+         |for (int $i = 0; $i < $size; $i++) {
+         |  UnsafeRow $matched = (UnsafeRow) $matches.apply($i);
+         |  $checkCondition
+         |  $numOutput.add(1);
+         |  ${consume(ctx, resultVars)}
          |}
        """.stripMargin
     }
@@ -257,21 +255,21 @@ case class BroadcastHashJoin(
     val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input)
     val matched = ctx.freshName("matched")
     val buildVars = genBuildSideVars(ctx, matched)
-    val resultVars = buildSide match {
-      case BuildLeft => buildVars ++ input
-      case BuildRight => input ++ buildVars
-    }
     val numOutput = metricTerm(ctx, "numOutputRows")
 
     // filter the output via condition
     val conditionPassed = ctx.freshName("conditionPassed")
     val checkCondition = if (condition.isDefined) {
-      ctx.currentVars = resultVars
-      val ev = BindReferences.bindReference(condition.get, this.output).gen(ctx)
+      val expr = condition.get
+      // evaluate the variables from build side that used by condition
+      val eval = evaluateRequiredVariables(buildPlan.output, buildVars, expr.references)
+      ctx.currentVars = input ++ buildVars
+      val ev = BindReferences.bindReference(expr, streamedPlan.output ++ buildPlan.output).gen(ctx)
       s"""
          |boolean $conditionPassed = true;
+         |${eval.trim}
+         |${ev.code}
          |if ($matched != null) {
-         |  ${ev.code}
          |  $conditionPassed = !${ev.isNull} && ${ev.value};
          |}
        """.stripMargin
@@ -279,17 +277,21 @@ case class BroadcastHashJoin(
       s"final boolean $conditionPassed = true;"
     }
 
+    val resultVars = buildSide match {
+      case BuildLeft => buildVars ++ input
+      case BuildRight => input ++ buildVars
+    }
     if (broadcastRelation.value.isInstanceOf[UniqueHashedRelation]) {
       s"""
          |// generate join key for stream side
          |${keyEv.code}
          |// find matches from HashedRelation
          |UnsafeRow $matched = $anyNull ? null: (UnsafeRow)$relationTerm.getValue(${keyEv.value});
-         |${buildVars.map(_.code).mkString("\n")}
          |${checkCondition.trim}
          |if (!$conditionPassed) {
-         |  // reset to null
-         |  ${buildVars.map(v => s"${v.isNull} = true;").mkString("\n")}
+         |  $matched = null;
+         |  // reset the variables those are already evaluated.
+         |  ${buildVars.filter(_.code == "").map(v => s"${v.isNull} = true;").mkString("\n")}
          |}
          |$numOutput.add(1);
          |${consume(ctx, resultVars)}
@@ -311,13 +313,11 @@ case class BroadcastHashJoin(
          |// the last iteration of this loop is to emit an empty row if there is no matched rows.
          |for (int $i = 0; $i <= $size; $i++) {
          |  UnsafeRow $matched = $i < $size ? (UnsafeRow) $matches.apply($i) : null;
-         |  ${buildVars.map(_.code).mkString("\n")}
          |  ${checkCondition.trim}
-         |  if ($conditionPassed && ($i < $size || !$found)) {
-         |    $found = true;
-         |    $numOutput.add(1);
-         |    ${consume(ctx, resultVars)}
-         |  }
+         |  if (!$conditionPassed || ($i == $size && $found)) continue;
+         |  $found = true;
+         |  $numOutput.add(1);
+         |  ${consume(ctx, resultVars)}
          |}
        """.stripMargin
     }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala
index 7ec4027188f145a1e3812f574ef26845354cbdad..cffd6f6032f2f2ba781a3ac8d52618a4c77a02bf 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala
@@ -306,11 +306,11 @@ case class SortMergeJoin(
       val (used, notUsed) = attributes.zip(variables).partition{ case (a, ev) =>
         condRefs.contains(a)
       }
-      val beforeCond = used.map(_._2.code).mkString("\n")
-      val afterCond = notUsed.map(_._2.code).mkString("\n")
+      val beforeCond = evaluateVariables(used.map(_._2))
+      val afterCond = evaluateVariables(notUsed.map(_._2))
       (beforeCond, afterCond)
     } else {
-      (variables.map(_.code).mkString("\n"), "")
+      (evaluateVariables(variables), "")
     }
   }
 
@@ -326,41 +326,48 @@ case class SortMergeJoin(
     val leftVars = createLeftVars(ctx, leftRow)
     val rightRow = ctx.freshName("rightRow")
     val rightVars = createRightVar(ctx, rightRow)
-    val resultVars = leftVars ++ rightVars
-
-    // Check condition
-    ctx.currentVars = resultVars
-    val cond = if (condition.isDefined) {
-      BindReferences.bindReference(condition.get, output).gen(ctx)
-    } else {
-      ExprCode("", "false", "true")
-    }
-    // Split the code of creating variables based on whether it's used by condition or not.
-    val loaded = ctx.freshName("loaded")
-    val (leftBefore, leftAfter) = splitVarsByCondition(left.output, leftVars)
-    val (rightBefore, rightAfter) = splitVarsByCondition(right.output, rightVars)
-
 
     val size = ctx.freshName("size")
     val i = ctx.freshName("i")
     val numOutput = metricTerm(ctx, "numOutputRows")
+    val (beforeLoop, condCheck) = if (condition.isDefined) {
+      // Split the code of creating variables based on whether it's used by condition or not.
+      val loaded = ctx.freshName("loaded")
+      val (leftBefore, leftAfter) = splitVarsByCondition(left.output, leftVars)
+      val (rightBefore, rightAfter) = splitVarsByCondition(right.output, rightVars)
+      // Generate code for condition
+      ctx.currentVars = leftVars ++ rightVars
+      val cond = BindReferences.bindReference(condition.get, output).gen(ctx)
+      // evaluate the columns those used by condition before loop
+      val before = s"""
+           |boolean $loaded = false;
+           |$leftBefore
+         """.stripMargin
+
+      val checking = s"""
+         |$rightBefore
+         |${cond.code}
+         |if (${cond.isNull} || !${cond.value}) continue;
+         |if (!$loaded) {
+         |  $loaded = true;
+         |  $leftAfter
+         |}
+         |$rightAfter
+     """.stripMargin
+      (before, checking)
+    } else {
+      (evaluateVariables(leftVars), "")
+    }
+
     s"""
        |while (findNextInnerJoinRows($leftInput, $rightInput)) {
        |  int $size = $matches.size();
-       |  boolean $loaded = false;
-       |  $leftBefore
+       |  ${beforeLoop.trim}
        |  for (int $i = 0; $i < $size; $i ++) {
        |    InternalRow $rightRow = (InternalRow) $matches.get($i);
-       |    $rightBefore
-       |    ${cond.code}
-       |    if (${cond.isNull} || !${cond.value}) continue;
-       |    if (!$loaded) {
-       |      $loaded = true;
-       |      $leftAfter
-       |    }
-       |    $rightAfter
+       |    ${condCheck.trim}
        |    $numOutput.add(1);
-       |    ${consume(ctx, resultVars)}
+       |    ${consume(ctx, leftVars ++ rightVars)}
        |  }
        |  if (shouldStop()) return;
        |}