diff --git a/streaming/src/main/scala/spark/streaming/DStream.scala b/streaming/src/main/scala/spark/streaming/DStream.scala
index e1be5ef51cc9cd8cc4c71c52b6f8dcb820dbb761..e125310861ebcd016dabbd3f25e29f9d95b1fbb7 100644
--- a/streaming/src/main/scala/spark/streaming/DStream.scala
+++ b/streaming/src/main/scala/spark/streaming/DStream.scala
@@ -441,7 +441,7 @@ abstract class DStream[T: ClassManifest] (
    * Return a new DStream in which each RDD has a single element generated by counting each RDD
    * of this DStream.
    */
-  def count(): DStream[Long] = this.map(_ => 1L).reduce(_ + _)
+  def count(): DStream[Long] = this.map(_ => (null, 1L)).transform(_.union(context.sparkContext.makeRDD(Seq((null, 0L)), 1))).reduceByKey(_ + _).map(_._2)
 
   /**
    * Return a new DStream in which each RDD contains the counts of each distinct value in
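[Reviewer note, not part of the patch] Why the union-with-a-seed trick: DStream.reduce is implemented on top of reduceByKey, and reduceByKey over an empty RDD yields an empty RDD, so the old count() emitted no element at all for an empty batch instead of 0. Seeding the pair RDD with a single (null, 0L) record guarantees reduceByKey always produces exactly one element. A minimal RDD-level sketch of the before/after behavior; it uses the modern org.apache.spark namespace rather than the pre-rename spark package this patch targets, and CountSeedSketch is an illustrative name:

    import org.apache.spark.{SparkConf, SparkContext}

    object CountSeedSketch {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext(
          new SparkConf().setMaster("local[2]").setAppName("count-seed-sketch"))

        val emptyBatch = sc.parallelize(Seq.empty[Int])

        // Old shape of count(): map to a pair, then reduce by key. Over an
        // empty RDD the reduceByKey produces nothing, so no 0 is emitted.
        val oldStyle = emptyBatch.map(_ => (null, 1L)).reduceByKey(_ + _).map(_._2)
        println(oldStyle.collect().toSeq) // List() -- the missing-count bug

        // New shape of count(): union in a (null, 0L) seed before reducing,
        // so an empty batch still reduces to a single 0.
        val newStyle = emptyBatch.map(_ => (null, 1L))
          .union(sc.makeRDD(Seq((null, 0L)), 1))
          .reduceByKey(_ + _)
          .map(_._2)
        println(newStyle.collect().toSeq) // List(0)

        sc.stop()
      }
    }
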
diff --git a/streaming/src/test/scala/spark/streaming/BasicOperationsSuite.scala b/streaming/src/test/scala/spark/streaming/BasicOperationsSuite.scala
index 8fce91853c77eed0f5d3e5e48aa2e93f13988e3e..168e1b7a557c963a98a81683ffe88d089ca7391a 100644
--- a/streaming/src/test/scala/spark/streaming/BasicOperationsSuite.scala
+++ b/streaming/src/test/scala/spark/streaming/BasicOperationsSuite.scala
@@ -90,9 +90,9 @@ class BasicOperationsSuite extends TestSuiteBase {
 
   test("count") {
     testOperation(
-      Seq(1 to 1, 1 to 2, 1 to 3, 1 to 4),
+      Seq(Seq(), 1 to 1, 1 to 2, 1 to 3, 1 to 4),
       (s: DStream[Int]) => s.count(),
-      Seq(Seq(1L), Seq(2L), Seq(3L), Seq(4L))
+      Seq(Seq(0L), Seq(1L), Seq(2L), Seq(3L), Seq(4L))
     )
   }
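
[Reviewer note, not part of the patch] The updated test prepends an empty batch and expects Seq(0L) for it. A standalone sketch of the same check using queueStream, which feeds one queued RDD per batch interval; EmptyBatchCountCheck and the 5-second sleep are illustrative assumptions, and it again uses the modern org.apache.spark.streaming namespace rather than spark.streaming:

    import scala.collection.mutable
    import org.apache.spark.SparkConf
    import org.apache.spark.rdd.RDD
    import org.apache.spark.streaming.{Seconds, StreamingContext}

    object EmptyBatchCountCheck {
      def main(args: Array[String]): Unit = {
        val conf = new SparkConf().setMaster("local[2]").setAppName("empty-batch-count")
        val ssc = new StreamingContext(conf, Seconds(1))

        // One queued RDD is consumed per batch when oneAtATime = true.
        val batches = mutable.Queue[RDD[Int]](
          ssc.sparkContext.parallelize(Seq.empty[Int]), // empty batch: expect 0
          ssc.sparkContext.parallelize(1 to 4)          // non-empty batch: expect 4
        )
        ssc.queueStream(batches, oneAtATime = true).count().print()

        ssc.start()
        Thread.sleep(5000) // long enough for both batches in this sketch
        ssc.stop()
      }
    }

With the fix, the first batch prints 0; before it, that interval produced no output at all, which is exactly what the added Seq() input and Seq(0L) expectation pin down.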