diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index 03a79520cbd3a87a5467fadec0f323159027dae5..57575f9ee09abfaa87e7c5896ad6594498bf17ba 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -25,6 +25,17 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.types._
 
+/**
+ * When planning take() or collect() operations, this special node is inserted at the top of
+ * the logical plan before invoking the query planner.
+ *
+ * Rules can pattern-match on this node in order to apply transformations that only take effect
+ * at the top of the logical query plan.
+ */
+case class ReturnAnswer(child: LogicalPlan) extends UnaryNode {
+  override def output: Seq[Attribute] = child.output
+}
+
 case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extends UnaryNode {
   override def output: Seq[Attribute] = projectList.map(_.toAttribute)
 
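
[Sketch, not part of the patch] ReturnAnswer adds no behaviour of its own; it is purely a marker that
lets planner strategies tell "the top of a plan being collected" apart from an identical subtree nested
deeper in the query. A minimal illustration of such a strategy is below; `TopLevelLimits` is a
hypothetical name, and like the real cases added to SparkStrategies later in this patch it has to live
inside a QueryPlanner[SparkPlan] so that `Strategy`, `planLater` and the logical/execution aliases are
in scope.

  object TopLevelLimits extends Strategy {
    def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
      // A Limit sitting directly under the ReturnAnswer marker is a terminal
      // limit (produced by take()/collect()) and can be planned specially.
      case logical.ReturnAnswer(logical.Limit(IntegerLiteral(limit), child)) =>
        execution.CollectLimit(limit, planLater(child)) :: Nil
      // Otherwise just strip the marker and plan the child as usual.
      case logical.ReturnAnswer(child) => planLater(child) :: Nil
      case _ => Nil
    }
  }
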
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
index 3770883af1e2f078478726cf718a4cdf595a02ed..97f65f18bfdcc4bb3d433c282022ee9172c0143d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
@@ -57,6 +57,69 @@ case class Exchange(
 
   override def output: Seq[Attribute] = child.output
 
+  private val serializer: Serializer = new UnsafeRowSerializer(child.output.size)
+
+  override protected def doPrepare(): Unit = {
+    // If an ExchangeCoordinator is needed, we register this Exchange operator
+    // with the coordinator when we do prepare. It is important to make sure
+    // we register this operator right before execution instead of registering
+    // it in the constructor, because it is possible that we create new
+    // instances of Exchange operators when we transform the physical plan
+    // (and then the ExchangeCoordinator would hold references to unneeded
+    // Exchanges). So, we should only call registerExchange just before we
+    // start to execute the plan.
+    coordinator match {
+      case Some(exchangeCoordinator) => exchangeCoordinator.registerExchange(this)
+      case None =>
+    }
+  }
+
+  /**
+   * Returns a [[ShuffleDependency]] that will partition rows of its child based on
+   * the partitioning scheme defined in `newPartitioning`. Those partitions of
+   * the returned ShuffleDependency will be the input of shuffle.
+   */
+  private[sql] def prepareShuffleDependency(): ShuffleDependency[Int, InternalRow, InternalRow] = {
+    Exchange.prepareShuffleDependency(child.execute(), child.output, newPartitioning, serializer)
+  }
+
+  /**
+   * Returns a [[ShuffledRowRDD]] that represents the post-shuffle dataset.
+   * This [[ShuffledRowRDD]] is created based on a given [[ShuffleDependency]] and an optional
+   * partition start indices array. If this optional array is defined, the returned
+   * [[ShuffledRowRDD]] will fetch pre-shuffle partitions based on indices of this array.
+   */
+  private[sql] def preparePostShuffleRDD(
+      shuffleDependency: ShuffleDependency[Int, InternalRow, InternalRow],
+      specifiedPartitionStartIndices: Option[Array[Int]] = None): ShuffledRowRDD = {
+    // If an array of partition start indices is provided, we need to use this array
+    // to create the ShuffledRowRDD. We also need to update newPartitioning so that
+    // it reflects the number of post-shuffle partitions.
+    specifiedPartitionStartIndices.foreach { indices =>
+      assert(newPartitioning.isInstanceOf[HashPartitioning])
+      newPartitioning = UnknownPartitioning(indices.length)
+    }
+    new ShuffledRowRDD(shuffleDependency, specifiedPartitionStartIndices)
+  }
+
+  protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") {
+    coordinator match {
+      case Some(exchangeCoordinator) =>
+        val shuffleRDD = exchangeCoordinator.postShuffleRDD(this)
+        assert(shuffleRDD.partitions.length == newPartitioning.numPartitions)
+        shuffleRDD
+      case None =>
+        val shuffleDependency = prepareShuffleDependency()
+        preparePostShuffleRDD(shuffleDependency)
+    }
+  }
+}
+
+object Exchange {
+  def apply(newPartitioning: Partitioning, child: SparkPlan): Exchange = {
+    Exchange(newPartitioning, child, coordinator = None: Option[ExchangeCoordinator])
+  }
+
   /**
    * Determines whether records must be defensively copied before being sent to the shuffle.
    * Several of Spark's shuffle components will buffer deserialized Java objects in memory. The
@@ -82,7 +145,7 @@ case class Exchange(
     // passed instead of directly passing the number of partitions in order to guard against
     // corner-cases where a partitioner constructed with `numPartitions` partitions may output
     // fewer partitions (like RangePartitioner, for example).
-    val conf = child.sqlContext.sparkContext.conf
+    val conf = SparkEnv.get.conf
     val shuffleManager = SparkEnv.get.shuffleManager
     val sortBasedShuffleOn = shuffleManager.isInstanceOf[SortShuffleManager]
     val bypassMergeThreshold = conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200)
@@ -117,30 +180,16 @@ case class Exchange(
     }
   }
 
-  private val serializer: Serializer = new UnsafeRowSerializer(child.output.size)
-
-  override protected def doPrepare(): Unit = {
-    // If an ExchangeCoordinator is needed, we register this Exchange operator
-    // to the coordinator when we do prepare. It is important to make sure
-    // we register this operator right before the execution instead of register it
-    // in the constructor because it is possible that we create new instances of
-    // Exchange operators when we transform the physical plan
-    // (then the ExchangeCoordinator will hold references of unneeded Exchanges).
-    // So, we should only call registerExchange just before we start to execute
-    // the plan.
-    coordinator match {
-      case Some(exchangeCoordinator) => exchangeCoordinator.registerExchange(this)
-      case None =>
-    }
-  }
-
   /**
    * Returns a [[ShuffleDependency]] that will partition rows of its child based on
    * the partitioning scheme defined in `newPartitioning`. Those partitions of
    * the returned ShuffleDependency will be the input of shuffle.
    */
-  private[sql] def prepareShuffleDependency(): ShuffleDependency[Int, InternalRow, InternalRow] = {
-    val rdd = child.execute()
+  private[sql] def prepareShuffleDependency(
+      rdd: RDD[InternalRow],
+      outputAttributes: Seq[Attribute],
+      newPartitioning: Partitioning,
+      serializer: Serializer): ShuffleDependency[Int, InternalRow, InternalRow] = {
     val part: Partitioner = newPartitioning match {
       case RoundRobinPartitioning(numPartitions) => new HashPartitioner(numPartitions)
       case HashPartitioning(_, n) =>
@@ -160,7 +209,7 @@ case class Exchange(
         // We need to use an interpreted ordering here because generated orderings cannot be
         // serialized and this ordering needs to be created on the driver in order to be passed into
         // Spark core code.
-        implicit val ordering = new InterpretedOrdering(sortingExpressions, child.output)
+        implicit val ordering = new InterpretedOrdering(sortingExpressions, outputAttributes)
         new RangePartitioner(numPartitions, rddForSampling, ascending = true)
       case SinglePartition =>
         new Partitioner {
@@ -180,7 +229,7 @@ case class Exchange(
           position
         }
       case h: HashPartitioning =>
-        val projection = UnsafeProjection.create(h.partitionIdExpression :: Nil, child.output)
+        val projection = UnsafeProjection.create(h.partitionIdExpression :: Nil, outputAttributes)
         row => projection(row).getInt(0)
       case RangePartitioning(_, _) | SinglePartition => identity
       case _ => sys.error(s"Exchange not implemented for $newPartitioning")
@@ -211,43 +260,6 @@ case class Exchange(
 
     dependency
   }
-
-  /**
-   * Returns a [[ShuffledRowRDD]] that represents the post-shuffle dataset.
-   * This [[ShuffledRowRDD]] is created based on a given [[ShuffleDependency]] and an optional
-   * partition start indices array. If this optional array is defined, the returned
-   * [[ShuffledRowRDD]] will fetch pre-shuffle partitions based on indices of this array.
-   */
-  private[sql] def preparePostShuffleRDD(
-      shuffleDependency: ShuffleDependency[Int, InternalRow, InternalRow],
-      specifiedPartitionStartIndices: Option[Array[Int]] = None): ShuffledRowRDD = {
-    // If an array of partition start indices is provided, we need to use this array
-    // to create the ShuffledRowRDD. Also, we need to update newPartitioning to
-    // update the number of post-shuffle partitions.
-    specifiedPartitionStartIndices.foreach { indices =>
-      assert(newPartitioning.isInstanceOf[HashPartitioning])
-      newPartitioning = UnknownPartitioning(indices.length)
-    }
-    new ShuffledRowRDD(shuffleDependency, specifiedPartitionStartIndices)
-  }
-
-  protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") {
-    coordinator match {
-      case Some(exchangeCoordinator) =>
-        val shuffleRDD = exchangeCoordinator.postShuffleRDD(this)
-        assert(shuffleRDD.partitions.length == newPartitioning.numPartitions)
-        shuffleRDD
-      case None =>
-        val shuffleDependency = prepareShuffleDependency()
-        preparePostShuffleRDD(shuffleDependency)
-    }
-  }
-}
-
-object Exchange {
-  def apply(newPartitioning: Partitioning, child: SparkPlan): Exchange = {
-    Exchange(newPartitioning, child, coordinator = None: Option[ExchangeCoordinator])
-  }
 }
 
 /**
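
[Sketch, not part of the patch] The net effect of the hunks above is that prepareShuffleDependency
moves onto the Exchange companion object and takes the pre-shuffle RDD, its output attributes, the
target partitioning and the serializer explicitly instead of reading them from `child` (the instance
method now just delegates to it), so the shuffle machinery can be reused by operators other than
Exchange. A rough usage sketch, mirroring what CollectLimit.doExecute does further down in this
patch, where `child` stands for any SparkPlan:

  val serializer: Serializer = new UnsafeRowSerializer(child.output.size)
  val dependency = Exchange.prepareShuffleDependency(
    child.execute(),   // pre-shuffle rows
    child.output,      // attributes used to build the partition-key projection
    SinglePartition,   // target partitioning scheme
    serializer)
  val shuffled: ShuffledRowRDD = new ShuffledRowRDD(dependency)
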
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
index 107570f9dbcc84cf6b4ac80f24fa71f68e2fc751..8616fe317034fe894d2c7c2c0231436c17380032 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.execution
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.SQLContext
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, ReturnAnswer}
 
 /**
  * The primary workflow for executing relational queries using Spark.  Designed to allow easy
@@ -44,7 +44,7 @@ class QueryExecution(val sqlContext: SQLContext, val logical: LogicalPlan) {
 
   lazy val sparkPlan: SparkPlan = {
     SQLContext.setActive(sqlContext)
-    sqlContext.planner.plan(optimizedPlan).next()
+    sqlContext.planner.plan(ReturnAnswer(optimizedPlan)).next()
   }
 
   // executedPlan should not be used to initialize any SparkPlan. It should be
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 830bb011beab4750a07b4a399210b71dcd217104..ee392e4e8debe90ea9bbbe9dbeb818027d120d42 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -338,8 +338,12 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
         execution.Sample(lb, ub, withReplacement, seed, planLater(child)) :: Nil
       case logical.LocalRelation(output, data) =>
         LocalTableScan(output, data) :: Nil
+      case logical.ReturnAnswer(logical.Limit(IntegerLiteral(limit), child)) =>
+        execution.CollectLimit(limit, planLater(child)) :: Nil
       case logical.Limit(IntegerLiteral(limit), child) =>
-        execution.Limit(limit, planLater(child)) :: Nil
+        val perPartitionLimit = execution.LocalLimit(limit, planLater(child))
+        val globalLimit = execution.GlobalLimit(limit, perPartitionLimit)
+        globalLimit :: Nil
       case logical.Union(unionChildren) =>
         execution.Union(unionChildren.map(planLater)) :: Nil
       case logical.Except(left, right) =>
@@ -358,6 +362,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
         BatchPythonEvaluation(udf, e.output, planLater(child)) :: Nil
       case LogicalRDD(output, rdd) => PhysicalRDD(output, rdd, "ExistingRDD") :: Nil
       case BroadcastHint(child) => planLater(child) :: Nil
+      case logical.ReturnAnswer(child) => planLater(child) :: Nil
       case _ => Nil
     }
   }
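
[Sketch, not part of the patch] With the two new cases above, a Limit that is the topmost operator of
a collected query is planned as CollectLimit, while any other Limit becomes a LocalLimit/GlobalLimit
pair; GlobalLimit's AllTuples requirement (see limit.scala below) is what later causes
EnsureRequirements to insert the single-partition exchange between the two. A rough illustration,
assuming a DataFrame `df` with an integer column "value":

  val terminal    = df.limit(2)                            // Limit is topmost under ReturnAnswer
  val nonTerminal = df.limit(2).filter(df("value") > 0)    // Limit is no longer topmost
  // terminal.queryExecution.sparkPlan     ~  CollectLimit(2, ...)
  // nonTerminal.queryExecution.sparkPlan  ~  Filter(..., GlobalLimit(2, LocalLimit(2, ...)))
  // In the executedPlan, EnsureRequirements then places an Exchange to a single
  // partition between the GlobalLimit and the LocalLimit.
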
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index 6e51c4d84824a3b09b2170215173d62ae50c11f0..f63e8a9b6d79ddfa3b5077c9e8ca9f58f61af36d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -17,16 +17,13 @@
 
 package org.apache.spark.sql.execution
 
-import org.apache.spark.{HashPartitioner, SparkEnv}
-import org.apache.spark.rdd.{PartitionwiseSampledRDD, RDD, ShuffledRDD}
-import org.apache.spark.shuffle.sort.SortShuffleManager
+import org.apache.spark.rdd.{PartitionwiseSampledRDD, RDD}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, ExpressionCanonicalizer}
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.execution.metric.SQLMetrics
 import org.apache.spark.sql.types.LongType
-import org.apache.spark.util.MutablePair
 import org.apache.spark.util.random.PoissonSampler
 
 case class Project(projectList: Seq[NamedExpression], child: SparkPlan)
@@ -306,96 +303,6 @@ case class Union(children: Seq[SparkPlan]) extends SparkPlan {
     sparkContext.union(children.map(_.execute()))
 }
 
-/**
- * Take the first limit elements. Note that the implementation is different depending on whether
- * this is a terminal operator or not. If it is terminal and is invoked using executeCollect,
- * this operator uses something similar to Spark's take method on the Spark driver. If it is not
- * terminal or is invoked using execute, we first take the limit on each partition, and then
- * repartition all the data to a single partition to compute the global limit.
- */
-case class Limit(limit: Int, child: SparkPlan)
-  extends UnaryNode {
-  // TODO: Implement a partition local limit, and use a strategy to generate the proper limit plan:
-  // partition local limit -> exchange into one partition -> partition local limit again
-
-  /** We must copy rows when sort based shuffle is on */
-  private def sortBasedShuffleOn = SparkEnv.get.shuffleManager.isInstanceOf[SortShuffleManager]
-
-  override def output: Seq[Attribute] = child.output
-  override def outputPartitioning: Partitioning = SinglePartition
-
-  override def executeCollect(): Array[InternalRow] = child.executeTake(limit)
-
-  protected override def doExecute(): RDD[InternalRow] = {
-    val rdd: RDD[_ <: Product2[Boolean, InternalRow]] = if (sortBasedShuffleOn) {
-      child.execute().mapPartitionsInternal { iter =>
-        iter.take(limit).map(row => (false, row.copy()))
-      }
-    } else {
-      child.execute().mapPartitionsInternal { iter =>
-        val mutablePair = new MutablePair[Boolean, InternalRow]()
-        iter.take(limit).map(row => mutablePair.update(false, row))
-      }
-    }
-    val part = new HashPartitioner(1)
-    val shuffled = new ShuffledRDD[Boolean, InternalRow, InternalRow](rdd, part)
-    shuffled.setSerializer(new SparkSqlSerializer(child.sqlContext.sparkContext.getConf))
-    shuffled.mapPartitionsInternal(_.take(limit).map(_._2))
-  }
-}
-
-/**
- * Take the first limit elements as defined by the sortOrder, and do projection if needed.
- * This is logically equivalent to having a [[Limit]] operator after a [[Sort]] operator,
- * or having a [[Project]] operator between them.
- * This could have been named TopK, but Spark's top operator does the opposite in ordering
- * so we name it TakeOrdered to avoid confusion.
- */
-case class TakeOrderedAndProject(
-    limit: Int,
-    sortOrder: Seq[SortOrder],
-    projectList: Option[Seq[NamedExpression]],
-    child: SparkPlan) extends UnaryNode {
-
-  override def output: Seq[Attribute] = {
-    val projectOutput = projectList.map(_.map(_.toAttribute))
-    projectOutput.getOrElse(child.output)
-  }
-
-  override def outputPartitioning: Partitioning = SinglePartition
-
-  // We need to use an interpreted ordering here because generated orderings cannot be serialized
-  // and this ordering needs to be created on the driver in order to be passed into Spark core code.
-  private val ord: InterpretedOrdering = new InterpretedOrdering(sortOrder, child.output)
-
-  private def collectData(): Array[InternalRow] = {
-    val data = child.execute().map(_.copy()).takeOrdered(limit)(ord)
-    if (projectList.isDefined) {
-      val proj = UnsafeProjection.create(projectList.get, child.output)
-      data.map(r => proj(r).copy())
-    } else {
-      data
-    }
-  }
-
-  override def executeCollect(): Array[InternalRow] = {
-    collectData()
-  }
-
-  // TODO: Terminal split should be implemented differently from non-terminal split.
-  // TODO: Pick num splits based on |limit|.
-  protected override def doExecute(): RDD[InternalRow] = sparkContext.makeRDD(collectData(), 1)
-
-  override def outputOrdering: Seq[SortOrder] = sortOrder
-
-  override def simpleString: String = {
-    val orderByString = sortOrder.mkString("[", ",", "]")
-    val outputString = output.mkString("[", ",", "]")
-
-    s"TakeOrderedAndProject(limit=$limit, orderBy=$orderByString, output=$outputString)"
-  }
-}
-
 /**
  * Return a new RDD that has exactly `numPartitions` partitions.
  * Similar to coalesce defined on an [[RDD]], this operation results in a narrow dependency, e.g.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
new file mode 100644
index 0000000000000000000000000000000000000000..256f4228ae99e05666dc75b1421096d2a41ec2cf
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
@@ -0,0 +1,122 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.serializer.Serializer
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.physical._
+
+
+/**
+ * Take the first `limit` elements and collect them to a single partition.
+ *
+ * This operator will be used when a logical `Limit` operation is the final operator in a
+ * logical plan, which happens when the user is collecting results back to the driver.
+ */
+case class CollectLimit(limit: Int, child: SparkPlan) extends UnaryNode {
+  override def output: Seq[Attribute] = child.output
+  override def outputPartitioning: Partitioning = SinglePartition
+  override def executeCollect(): Array[InternalRow] = child.executeTake(limit)
+  private val serializer: Serializer = new UnsafeRowSerializer(child.output.size)
+  protected override def doExecute(): RDD[InternalRow] = {
+    val shuffled = new ShuffledRowRDD(
+      Exchange.prepareShuffleDependency(child.execute(), child.output, SinglePartition, serializer))
+    shuffled.mapPartitionsInternal(_.take(limit))
+  }
+}
+
+/**
+ * Helper trait which defines methods that are shared by both [[LocalLimit]] and [[GlobalLimit]].
+ */
+trait BaseLimit extends UnaryNode {
+  val limit: Int
+  override def output: Seq[Attribute] = child.output
+  override def outputOrdering: Seq[SortOrder] = child.outputOrdering
+  override def outputPartitioning: Partitioning = child.outputPartitioning
+  protected override def doExecute(): RDD[InternalRow] = child.execute().mapPartitions { iter =>
+    iter.take(limit)
+  }
+}
+
+/**
+ * Take the first `limit` elements of each child partition, but do not collect or shuffle them.
+ */
+case class LocalLimit(limit: Int, child: SparkPlan) extends BaseLimit {
+  override def outputOrdering: Seq[SortOrder] = child.outputOrdering
+}
+
+/**
+ * Take the first `limit` elements of the child's single output partition.
+ */
+case class GlobalLimit(limit: Int, child: SparkPlan) extends BaseLimit {
+  override def requiredChildDistribution: List[Distribution] = AllTuples :: Nil
+}
+
+/**
+ * Take the first `limit` elements as defined by the sortOrder, and do projection if needed.
+ * This is logically equivalent to having a Limit operator after a [[Sort]] operator,
+ * or having a [[Project]] operator between them.
+ * This could have been named TopK, but Spark's top operator does the opposite in ordering
+ * so we name it TakeOrderedAndProject to avoid confusion.
+ */
+case class TakeOrderedAndProject(
+    limit: Int,
+    sortOrder: Seq[SortOrder],
+    projectList: Option[Seq[NamedExpression]],
+    child: SparkPlan) extends UnaryNode {
+
+  override def output: Seq[Attribute] = {
+    val projectOutput = projectList.map(_.map(_.toAttribute))
+    projectOutput.getOrElse(child.output)
+  }
+
+  override def outputPartitioning: Partitioning = SinglePartition
+
+  // We need to use an interpreted ordering here because generated orderings cannot be serialized
+  // and this ordering needs to be created on the driver in order to be passed into Spark core code.
+  private val ord: InterpretedOrdering = new InterpretedOrdering(sortOrder, child.output)
+
+  private def collectData(): Array[InternalRow] = {
+    val data = child.execute().map(_.copy()).takeOrdered(limit)(ord)
+    if (projectList.isDefined) {
+      val proj = UnsafeProjection.create(projectList.get, child.output)
+      data.map(r => proj(r).copy())
+    } else {
+      data
+    }
+  }
+
+  override def executeCollect(): Array[InternalRow] = {
+    collectData()
+  }
+
+  // TODO: Terminal split should be implemented differently from non-terminal split.
+  // TODO: Pick num splits based on |limit|.
+  protected override def doExecute(): RDD[InternalRow] = sparkContext.makeRDD(collectData(), 1)
+
+  override def outputOrdering: Seq[SortOrder] = sortOrder
+
+  override def simpleString: String = {
+    val orderByString = sortOrder.mkString("[", ",", "]")
+    val outputString = output.mkString("[", ",", "]")
+
+    s"TakeOrderedAndProject(limit=$limit, orderBy=$orderByString, output=$outputString)"
+  }
+}
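
[Sketch, not part of the patch] CollectLimit above has two execution paths: when results are pulled to
the driver, executeCollect() delegates to child.executeTake(limit), which scans a growing number of
partitions until it has enough rows and never shuffles; only when the operator is executed as an RDD
does doExecute() shuffle all of the child's rows into a single partition and take `limit` rows there.
Roughly, assuming a DataFrame `df`:

  df.limit(5).collect()   // executeCollect(): child.executeTake(5), no shuffle
  df.limit(5).rdd         // doExecute(): shuffle the child's rows to one partition,
                          //              then take(5) on that partition
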
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
index adaeb513bc1b8094bec80ec06da59eb4fdfbe8a2..a64ad4038c7c32e903409c6e510c2aa6e8472e1f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -181,6 +181,12 @@ class PlannerSuite extends SharedSQLContext {
     }
   }
 
+  test("terminal limits use CollectLimit") {
+    val query = testData.select('value).limit(2)
+    val planned = query.queryExecution.sparkPlan
+    assert(planned.isInstanceOf[CollectLimit])
+  }
+
   test("PartitioningCollection") {
     withTempTable("normal", "small", "tiny") {
       testData.registerTempTable("normal")
@@ -200,7 +206,7 @@ class PlannerSuite extends SharedSQLContext {
           ).queryExecution.executedPlan.collect {
             case exchange: Exchange => exchange
           }.length
-          assert(numExchanges === 3)
+          assert(numExchanges === 5)
         }
 
         {
@@ -215,7 +221,7 @@ class PlannerSuite extends SharedSQLContext {
           ).queryExecution.executedPlan.collect {
             case exchange: Exchange => exchange
           }.length
-          assert(numExchanges === 3)
+          assert(numExchanges === 5)
         }
 
       }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala
index 6259453da26a1396c73ef6b817879b4fece0b1c8..cb6d68dc3ac4675f5b461f7349c6a8aadf5eec15 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala
@@ -56,8 +56,8 @@ class SortSuite extends SparkPlanTest with SharedSQLContext {
   test("sort followed by limit") {
     checkThatPlansAgree(
       (1 to 100).map(v => Tuple1(v)).toDF("a"),
-      (child: SparkPlan) => Limit(10, Sort('a.asc :: Nil, global = true, child = child)),
-      (child: SparkPlan) => Limit(10, ReferenceSort('a.asc :: Nil, global = true, child)),
+      (child: SparkPlan) => GlobalLimit(10, Sort('a.asc :: Nil, global = true, child = child)),
+      (child: SparkPlan) => GlobalLimit(10, ReferenceSort('a.asc :: Nil, global = true, child)),
       sortAnswers = false
     )
   }