diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala
index 524285bc87123a086ff0bfbc3c56ec7d04721325..a84e180ad1dd8ab6afae64025ad9bf1f30f4288c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala
@@ -93,7 +93,7 @@ case class Expand(
     child.asInstanceOf[CodegenSupport].produce(ctx, this)
   }
 
-  override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = {
+  override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: String): String = {
     /*
      * When the projections list looks like:
      *   expr1A, exprB, expr1C
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala
index 2ea889ea72c75080d6b5c5062913295ebfbdcd01..5a67cd0c24b44a5177f8e1f311ca1a3cf576faa5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala
@@ -105,6 +105,8 @@ case class Sort(
   // Name of sorter variable used in codegen.
   private var sorterVariable: String = _
 
+  // Sort can insert UnsafeRows straight into its sorter, so prefer receiving the
+  // child's output as an UnsafeRow rather than as evaluated column variables.
+  override def preferUnsafeRow: Boolean = true
+
   override protected def doProduce(ctx: CodegenContext): String = {
     val needToSort = ctx.freshName("needToSort")
     ctx.addMutableState("boolean", needToSort, s"$needToSort = true;")
@@ -153,18 +155,22 @@ case class Sort(
      """.stripMargin.trim
   }
 
-  override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = {
-    val colExprs = child.output.zipWithIndex.map { case (attr, i) =>
-      BoundReference(i, attr.dataType, attr.nullable)
-    }
+  override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: String): String = {
+    if (row != null) {
+      // The child handed its output over as an UnsafeRow variable, so it can go
+      // straight into the sorter without building a new projection.
+      s"$sorterVariable.insertRow((UnsafeRow) $row);"
+    } else {
+      val colExprs = child.output.zipWithIndex.map { case (attr, i) =>
+        BoundReference(i, attr.dataType, attr.nullable)
+      }
 
-    ctx.currentVars = input
-    val code = GenerateUnsafeProjection.createCode(ctx, colExprs)
+      ctx.currentVars = input
+      val code = GenerateUnsafeProjection.createCode(ctx, colExprs)
 
-    s"""
-       | // Convert the input attributes to an UnsafeRow and add it to the sorter
-       | ${code.code}
-       | $sorterVariable.insertRow(${code.value});
-     """.stripMargin.trim
+      s"""
+         | // Convert the input attributes to an UnsafeRow and add it to the sorter
+         | ${code.code}
+         | $sorterVariable.insertRow(${code.value});
+       """.stripMargin.trim
+    }
   }
 }
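
Taken together with InputAdapter below, the fast path means a child that already
produces UnsafeRows feeds the sorter with no per-column evaluation at all. The
fused code comes out roughly as follows, in the same template style as above
($input, $row and $sorterVariable standing for the codegen-generated names; a
sketch, not actual compiler output:

    s"""
       | while (!shouldStop() && $input.hasNext()) {
       |   InternalRow $row = (InternalRow) $input.next();
       |   /*** CONSUME: Sort */
       |   $sorterVariable.insertRow((UnsafeRow) $row);
       | }
     """.stripMargin

The else branch keeps the old behavior: evaluate the output columns, project
them with GenerateUnsafeProjection, and insert the resulting UnsafeRow.
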
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
index dd831e60cbf5e48f5cf8850798ac1d6ba96bb503..e8e42d72d4e66600a15974dfc9d83b0c02c4ab74 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
@@ -65,7 +65,12 @@ trait CodegenSupport extends SparkPlan {
   /**
     * Which SparkPlan is calling produce() of this one. It's itself for the first SparkPlan.
     */
-  private var parent: CodegenSupport = null
+  protected var parent: CodegenSupport = null
+
+  /**
+    * Whether this SparkPlan prefers to accept UnsafeRow as input in doConsume.
+    * If true and the upstream operator has its current row available as an
+    * UnsafeRow, consume() skips evaluating the required column variables and
+    * doConsume() is expected to work on that row directly.
+    */
+  def preferUnsafeRow: Boolean = false
 
   /**
     * Returns all the RDDs of InternalRow which generates the input rows.
@@ -176,11 +181,20 @@ trait CodegenSupport extends SparkPlan {
       } else {
         input
       }
+
+    val evaluated =
+      if (row != null && preferUnsafeRow) {
+        // Current plan consumes the UnsafeRow directly; skip evaluating columns.
+        ""
+      } else {
+        evaluateRequiredVariables(child.output, inputVars, usedInputs)
+      }
+
     s"""
        |
        |/*** CONSUME: ${toCommentSafeString(this.simpleString)} */
-       |${evaluateRequiredVariables(child.output, inputVars, usedInputs)}
-       |${doConsume(ctx, inputVars)}
+       |${evaluated}
+       |${doConsume(ctx, inputVars, row)}
      """.stripMargin
   }
 
@@ -195,7 +209,7 @@ trait CodegenSupport extends SparkPlan {
     *   if (isNull1 || !value2) continue;
     *   # call consume(), which will call parent.doConsume()
     */
-  protected def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = {
+  protected def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: String): String = {
     throw new UnsupportedOperationException
   }
 }
@@ -238,7 +252,7 @@ case class InputAdapter(child: SparkPlan) extends UnaryNode with CodegenSupport
     s"""
        | while (!shouldStop() && $input.hasNext()) {
        |   InternalRow $row = (InternalRow) $input.next();
-       |   ${consume(ctx, columns).trim}
+       |   ${consume(ctx, columns, row).trim}
        | }
      """.stripMargin
   }
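
Sort is the only operator that opts in here; the other doConsume implementations
in this patch (Expand above, TungstenAggregate, Project, Filter,
BroadcastHashJoin, BaseLimit and the debug node below) only pick up the widened
signature and keep preferUnsafeRow at its default, so consume() still evaluates
their input columns as before. For a new operator, opting in takes both hooks.
A minimal sketch mirroring Sort (sinkVariable and colExprs are hypothetical
names, standing in for mutable state created in doProduce and the operator's
bound output columns):

    override def preferUnsafeRow: Boolean = true

    override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: String): String = {
      if (row != null) {
        // Fast path: `row` names an UnsafeRow variable in the generated code.
        s"$sinkVariable.add((UnsafeRow) $row);"
      } else {
        // Slow path: evaluate the columns and project them into an UnsafeRow first.
        ctx.currentVars = input
        val code = GenerateUnsafeProjection.createCode(ctx, colExprs)
        s"""
           | ${code.code}
           | $sinkVariable.add(${code.value});
         """.stripMargin.trim
      }
    }
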
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
index f856634cf7b66d4d61112bd1cc8cc633f9669354..1c4d594cd863ee0e300b78c1c878efc2c1b24b55 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
@@ -139,7 +139,7 @@ case class TungstenAggregate(
     }
   }
 
-  override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = {
+  override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: String): String = {
     if (groupingExpressions.isEmpty) {
       doConsumeWithoutKeys(ctx, input)
     } else {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index 49012982273af271a2a5d6933c34bd7131462b1b..6ebbc8be6f8bc8159a6998505ff34cf25e0f25c7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -49,7 +49,7 @@ case class Project(projectList: Seq[NamedExpression], child: SparkPlan)
     references.filter(a => usedMoreThanOnce.contains(a.exprId))
   }
 
-  override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = {
+  override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: String): String = {
     val exprs = projectList.map(x =>
       ExpressionCanonicalizer.execute(BindReferences.bindReference(x, child.output)))
     ctx.currentVars = input
@@ -88,7 +88,7 @@ case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode wit
     child.asInstanceOf[CodegenSupport].produce(ctx, this)
   }
 
-  override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = {
+  override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: String): String = {
     val numOutput = metricTerm(ctx, "numOutputRows")
     val expr = ExpressionCanonicalizer.execute(
       BindReferences.bindReference(condition, child.output))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
index fed88b8c0a1179f0678c6f9f6e0720ff600b7543..034bf152620ded4f3965ad7d94677b39b7fc605c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
@@ -136,7 +136,7 @@ package object debug {
       child.asInstanceOf[CodegenSupport].produce(ctx, this)
     }
 
-    override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = {
+    override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: String): String = {
       consume(ctx, input)
     }
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
index c52662a61e7f82862ef12a44471b8e829e5ddcdb..4c8f8080a98d7ac4ac7a2ee1fd46e613ac60beab 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
@@ -107,7 +107,7 @@ case class BroadcastHashJoin(
     streamedPlan.asInstanceOf[CodegenSupport].produce(ctx, this)
   }
 
-  override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = {
+  override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: String): String = {
     if (joinType == Inner) {
       codegenInner(ctx, input)
     } else {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
index 5a7516b7f9c72beccb84ee5bb2c23d46d100091f..ca624a5a84e6c9804c7277d62a534d56f612595d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
@@ -65,7 +65,7 @@ trait BaseLimit extends UnaryNode with CodegenSupport {
     child.asInstanceOf[CodegenSupport].produce(ctx, this)
   }
 
-  override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = {
+  override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: String): String = {
     val stopEarly = ctx.freshName("stopEarly")
     ctx.addMutableState("boolean", stopEarly, s"$stopEarly = false;")