diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index f2001afae4ee58851dd94f66e5afa12aa57b7517..fa4b9c7b688ea3314072578b9481c5bcfc581544 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -323,7 +323,10 @@ class SchemaRDD(RDD): def count(self): """ - Return the number of elements in this RDD. + Return the number of elements in this RDD. Unlike the base RDD + implementation of count, this implementation leverages the query + optimizer to compute the count on the SchemaRDD, which supports + features such as filter pushdown. >>> srdd = sqlCtx.inferSchema(rdd) >>> srdd.count() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala index 452da3d02310d61005723ffcde27ba39a871c77c..9883ebc0b3c62fc7b29d3a4f3aa61a2df5b3fec4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala @@ -276,12 +276,12 @@ class SchemaRDD( /** * :: Experimental :: - * Overriding base RDD implementation to leverage query optimizer + * Return the number of elements in the RDD. Unlike the base RDD implementation of count, this + * implementation leverages the query optimizer to compute the count on the SchemaRDD, which + * supports features such as filter pushdown. */ @Experimental - override def count(): Long = { - groupBy()(Count(Literal(1))).collect().head.getLong(0) - } + override def count(): Long = groupBy()(Count(Literal(1))).collect().head.getLong(0) /** * :: Experimental :: diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala index 233132a2fec695955bf4d1e18693161cbc514023..94ba13b14b33d31a782b405324ee5eba5343552c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala @@ -124,7 +124,7 @@ class DslQuerySuite extends QueryTest { } test("zero count") { - assert(testData4.count() === 0) + assert(emptyTableData.count() === 0) } test("inner join where, one match per row") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala index b1eecb4dd3be4a89f256f8e91e22cb16524f1314..944f520e43515ea47d8a6c79250fa2d5c47f73e7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala @@ -47,7 +47,7 @@ object TestData { (1, null) :: (2, 2) :: Nil) - val testData4 = logical.LocalRelation('a.int, 'b.int) + val emptyTableData = logical.LocalRelation('a.int, 'b.int) case class UpperCaseData(N: Int, L: String) val upperCaseData =