From d66642e3978a76977414c2fdaedebaad35662667 Mon Sep 17 00:00:00 2001
From: Reynold Xin <rxin@apache.org>
Date: Sun, 25 May 2014 01:44:49 -0700
Subject: [PATCH] SPARK-1822: Some minor cleanup work on SchemaRDD.count()

Minor cleanup following #841.

Author: Reynold Xin <rxin@apache.org>

Closes #868 from rxin/schema-count and squashes the following commits:

5442651 [Reynold Xin] SPARK-1822: Some minor cleanup work on SchemaRDD.count()
---
 python/pyspark/sql.py                                     | 5 ++++-
 .../src/main/scala/org/apache/spark/sql/SchemaRDD.scala   | 8 ++++----
 .../test/scala/org/apache/spark/sql/DslQuerySuite.scala   | 2 +-
 .../src/test/scala/org/apache/spark/sql/TestData.scala    | 2 +-
 4 files changed, 10 insertions(+), 7 deletions(-)

diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py
index f2001afae4..fa4b9c7b68 100644
--- a/python/pyspark/sql.py
+++ b/python/pyspark/sql.py
@@ -323,7 +323,10 @@ class SchemaRDD(RDD):
 
     def count(self):
         """
-        Return the number of elements in this RDD.
+        Return the number of elements in this RDD. Unlike the base RDD
+        implementation of count, this implementation leverages the query
+        optimizer to compute the count on the SchemaRDD, which supports
+        features such as filter pushdown.
 
         >>> srdd = sqlCtx.inferSchema(rdd)
         >>> srdd.count()
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
index 452da3d023..9883ebc0b3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
@@ -276,12 +276,12 @@ class SchemaRDD(
 
   /**
    * :: Experimental ::
-   * Overriding base RDD implementation to leverage query optimizer
+   * Return the number of elements in the RDD. Unlike the base RDD implementation of count, this
+   * implementation leverages the query optimizer to compute the count on the SchemaRDD, which
+   * supports features such as filter pushdown.
    */
   @Experimental
-  override def count(): Long = {
-    groupBy()(Count(Literal(1))).collect().head.getLong(0)
-  }
+  override def count(): Long = groupBy()(Count(Literal(1))).collect().head.getLong(0)
 
   /**
    * :: Experimental ::
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala
index 233132a2fe..94ba13b14b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala
@@ -124,7 +124,7 @@ class DslQuerySuite extends QueryTest {
   }
 
   test("zero count") {
-    assert(testData4.count() === 0)
+    assert(emptyTableData.count() === 0)
   }
 
   test("inner join where, one match per row") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
index b1eecb4dd3..944f520e43 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
@@ -47,7 +47,7 @@ object TestData {
       (1, null) ::
       (2, 2) :: Nil)
 
-  val testData4 = logical.LocalRelation('a.int, 'b.int)
+  val emptyTableData = logical.LocalRelation('a.int, 'b.int)
 
   case class UpperCaseData(N: Int, L: String)
   val upperCaseData =
-- 
GitLab