diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index bf7c47b7261a9e8944ad433fe4407b9b54f6d002..d51309f7ef5aaa600f2422bedf9a771645b31425 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -520,6 +520,25 @@ class DataFrame(object):
 
     orderBy = sort
 
+    def describe(self, *cols):
+        """Computes statistics for numeric columns.
+
+        This include count, mean, stddev, min, and max. If no columns are
+        given, this function computes statistics for all numerical columns.
+
+        >>> df.describe().show()
+        summary age
+        count   2
+        mean    3.5
+        stddev  1.5
+        min     2
+        max     5
+        """
+        cols = ListConverter().convert(cols,
+                                       self.sql_ctx._sc._gateway._gateway_client)
+        jdf = self._jdf.describe(self.sql_ctx._sc._jvm.PythonUtils.toSeq(cols))
+        return DataFrame(jdf, self.sql_ctx)
+
     def head(self, n=None):
         """ Return the first `n` rows or the first row if n is None.
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
index db561825e676b2108d2ffefe618f3c1a3d7075c6..4c80359cf07af2acd87586fcc4db4c98457b388d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -33,7 +33,7 @@ import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.api.python.SerDeUtil
 import org.apache.spark.rdd.RDD
 import org.apache.spark.storage.StorageLevel
-import org.apache.spark.sql.catalyst.{expressions, ScalaReflection, SqlParser}
+import org.apache.spark.sql.catalyst.{ScalaReflection, SqlParser}
 import org.apache.spark.sql.catalyst.analysis.{UnresolvedRelation, ResolvedStar}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.{JoinType, Inner}
@@ -41,7 +41,7 @@ import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.execution.{EvaluatePython, ExplainCommand, LogicalRDD}
 import org.apache.spark.sql.jdbc.JDBCWriteDetails
 import org.apache.spark.sql.json.JsonRDD
-import org.apache.spark.sql.types.{NumericType, StructType, StructField, StringType}
+import org.apache.spark.sql.types._
 import org.apache.spark.sql.sources.{ResolvedDataSource, CreateTableUsingAsSelect}
 import org.apache.spark.util.Utils
 
@@ -752,15 +752,17 @@ class DataFrame private[sql](
   }
 
   /**
-   * Compute numerical statistics for given columns of this [[DataFrame]]:
-   * count, mean (avg), stddev (standard deviation), min, max.
-   * Each row of the resulting [[DataFrame]] contains column with statistic name
-   * and columns with statistic results for each given column.
-   * If no columns are given then computes for all numerical columns.
+   * Computes statistics for numeric columns, including count, mean, stddev, min, and max.
+   * If no columns are given, this function computes statistics for all numerical columns.
+   *
+   * This function is meant for exploratory data analysis, as we make no guarantee about the
+   * backward compatibility of the schema of the resulting [[DataFrame]]. If you want to
+   * programmatically compute summary statistics, use the `agg` function instead.
    *
    * {{{
-   *   df.describe("age", "height")
+   *   df.describe("age", "height").show()
    *
+   *   // output:
    *   // summary age   height
    *   // count   10.0  10.0
    *   // mean    53.3  178.05
@@ -768,13 +770,17 @@ class DataFrame private[sql](
    *   // min     18.0  163.0
    *   // max     92.0  192.0
    * }}}
+   *
+   * @group action
    */
   @scala.annotation.varargs
   def describe(cols: String*): DataFrame = {
 
-    def stddevExpr(expr: Expression) =
+    // TODO: Add stddev as an expression, and remove it from here.
+    def stddevExpr(expr: Expression): Expression =
       Sqrt(Subtract(Average(Multiply(expr, expr)), Multiply(Average(expr), Average(expr))))
 
+    // The list of summary statistics to compute, in the form of expressions.
     val statistics = List[(String, Expression => Expression)](
       "count" -> Count,
       "mean" -> Average,
@@ -782,24 +788,28 @@ class DataFrame private[sql](
       "min" -> Min,
       "max" -> Max)
 
-    val aggCols = (if (cols.isEmpty) numericColumns.map(_.prettyString) else cols).toList
+    val outputCols = (if (cols.isEmpty) numericColumns.map(_.prettyString) else cols).toList
 
-    val localAgg = if (aggCols.nonEmpty) {
+    val ret: Seq[Row] = if (outputCols.nonEmpty) {
       val aggExprs = statistics.flatMap { case (_, colToAgg) =>
-        aggCols.map(c => Column(colToAgg(Column(c).expr)).as(c))
+        outputCols.map(c => Column(colToAgg(Column(c).expr)).as(c))
       }
 
-      agg(aggExprs.head, aggExprs.tail: _*).head().toSeq
-        .grouped(aggCols.size).toSeq.zip(statistics).map { case (aggregation, (statistic, _)) =>
-        Row(statistic :: aggregation.toList: _*)
+      val row = agg(aggExprs.head, aggExprs.tail: _*).head().toSeq
+
+      // Pivot the data so each summary is one row
+      row.grouped(outputCols.size).toSeq.zip(statistics).map {
+        case (aggregation, (statistic, _)) => Row(statistic :: aggregation.toList: _*)
       }
     } else {
+      // If there are no output columns, just output a single column that contains the stats.
       statistics.map { case (name, _) => Row(name) }
     }
 
-    val schema = StructType(("summary" :: aggCols).map(StructField(_, StringType)))
-    val rowRdd = sqlContext.sparkContext.parallelize(localAgg)
-    sqlContext.createDataFrame(rowRdd, schema)
+    // The first column is string type, and the rest are double type.
+    val schema = StructType(
+      StructField("summary", StringType) :: outputCols.map(StructField(_, DoubleType))).toAttributes
+    LocalRelation(schema, ret)
   }
 
   /**
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index afbedd1e5825d57fd1d3b88583ba102f4ebd39a3..fbc4065a9666ce9b4581ffcf6c79ca555675ab74 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -444,7 +444,6 @@ class DataFrameSuite extends QueryTest {
   }
 
   test("describe") {
-
     val describeTestData = Seq(
       ("Bob",   16, 176),
       ("Alice", 32, 164),
@@ -465,7 +464,7 @@ class DataFrameSuite extends QueryTest {
       Row("min",     null, null),
       Row("max",     null, null))
 
-    def getSchemaAsSeq(df: DataFrame) = df.schema.map(_.name).toSeq
+    def getSchemaAsSeq(df: DataFrame): Seq[String] = df.schema.map(_.name)
 
     val describeTwoCols = describeTestData.describe("age", "height")
     assert(getSchemaAsSeq(describeTwoCols) === Seq("summary", "age", "height"))