diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py
index f2001afae4ee58851dd94f66e5afa12aa57b7517..fa4b9c7b688ea3314072578b9481c5bcfc581544 100644
--- a/python/pyspark/sql.py
+++ b/python/pyspark/sql.py
@@ -323,7 +323,10 @@ class SchemaRDD(RDD):
 
     def count(self):
         """
-        Return the number of elements in this RDD.
+        Return the number of elements in this RDD. Unlike the base RDD
+        implementation of count(), this implementation leverages the query
+        optimizer to compute the count on the SchemaRDD, which lets it
+        benefit from optimizations such as filter pushdown.
 
         >>> srdd = sqlCtx.inferSchema(rdd)
         >>> srdd.count()
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
index 452da3d02310d61005723ffcde27ba39a871c77c..9883ebc0b3c62fc7b29d3a4f3aa61a2df5b3fec4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
@@ -276,12 +276,12 @@ class SchemaRDD(
 
   /**
    * :: Experimental ::
-   * Overriding base RDD implementation to leverage query optimizer
+   * Return the number of elements in this RDD. Unlike the base RDD implementation of count(),
+   * this implementation leverages the query optimizer to compute the count on the SchemaRDD,
+   * which lets it benefit from optimizations such as filter pushdown.
    */
   @Experimental
-  override def count(): Long = {
-    groupBy()(Count(Literal(1))).collect().head.getLong(0)
-  }
+  override def count(): Long = groupBy()(Count(Literal(1))).collect().head.getLong(0)
 
   /**
    * :: Experimental ::
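
To make the docstrings in both files above concrete, here is a minimal, self-contained sketch of the optimizer-backed `count()` through the Scala DSL. It is illustrative only, not part of this patch: the `Person` case class, the sample data, and the local master are assumptions, and the pushdown benefit only materializes against a data source that can act on the pushed filter.

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext

object CountSketch {
  // Hypothetical schema for illustration; not part of the patch.
  case class Person(name: String, age: Int)

  def main(args: Array[String]) {
    val sc = new SparkContext(new SparkConf().setMaster("local").setAppName("count-sketch"))
    val sqlContext = new SQLContext(sc)
    // Brings in the implicit RDD[Product] => SchemaRDD conversion
    // and the symbol-based expression DSL used in where() below.
    import sqlContext._

    val people = sc.parallelize(Seq(Person("alice", 30), Person("bob", 20)))

    // Because count() is now a one-row aggregate over the logical plan,
    // the where() below is planned and optimized together with the count
    // instead of count() running as a plain RDD operation after the fact.
    val adults: Long = people.where('age > 21).count()
    println(adults) // 1

    sc.stop()
  }
}
```

Under the hood the override builds the same plan a user would get by writing `groupBy()(Count(Literal(1)))` directly: a global aggregate with no grouping expressions, whose single result row carries the count in column 0.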
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala
index 233132a2fec695955bf4d1e18693161cbc514023..94ba13b14b33d31a782b405324ee5eba5343552c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala
@@ -124,7 +124,7 @@ class DslQuerySuite extends QueryTest {
   }
 
   test("zero count") {
-    assert(testData4.count() === 0)
+    assert(emptyTableData.count() === 0)
   }
 
   test("inner join where, one match per row") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
index b1eecb4dd3be4a89f256f8e91e22cb16524f1314..944f520e43515ea47d8a6c79250fa2d5c47f73e7 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
@@ -47,7 +47,7 @@ object TestData {
       (1, null) ::
       (2, 2) :: Nil)
 
-  val testData4 = logical.LocalRelation('a.int, 'b.int)
+  val emptyTableData = logical.LocalRelation('a.int, 'b.int)
 
   case class UpperCaseData(N: Int, L: String)
   val upperCaseData =