From b928f543845ddd39e914a0e8f0b0205fd86100c5 Mon Sep 17 00:00:00 2001
From: Paavo <pparkkin@gmail.com>
Date: Wed, 10 Jun 2015 23:17:42 +0100
Subject: [PATCH] [SPARK-8200] [MLLIB] Check for empty RDDs in
 StreamingLinearAlgorithm

Test cases for both StreamingLinearRegression and StreamingLogisticRegression, and code fix.

Edit:
This contribution is my original work and I license the work to the project under the project's open source license.

Author: Paavo <pparkkin@gmail.com>

Closes #6713 from pparkkin/streamingmodel-empty-rdd and squashes the following commits:

ff5cd78 [Paavo] Update strings to use interpolation.
db234cf [Paavo] Use !rdd.isEmpty.
54ad89e [Paavo] Test case for empty stream.
393e36f [Paavo] Ignore empty RDDs.
0bfc365 [Paavo] Test case for empty stream.
---
 .../regression/StreamingLinearAlgorithm.scala  | 14 ++++++++------
 .../StreamingLogisticRegressionSuite.scala     | 17 +++++++++++++++++
 .../StreamingLinearRegressionSuite.scala       | 18 ++++++++++++++++++
 3 files changed, 43 insertions(+), 6 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala
index aee51bf22d..141052ba81 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala
@@ -83,13 +83,15 @@ abstract class StreamingLinearAlgorithm[
       throw new IllegalArgumentException("Model must be initialized before starting training.")
     }
     data.foreachRDD { (rdd, time) =>
-      model = Some(algorithm.run(rdd, model.get.weights))
-      logInfo("Model updated at time %s".format(time.toString))
-      val display = model.get.weights.size match {
-        case x if x > 100 => model.get.weights.toArray.take(100).mkString("[", ",", "...")
-        case _ => model.get.weights.toArray.mkString("[", ",", "]")
+      if (!rdd.isEmpty) {
+        model = Some(algorithm.run(rdd, model.get.weights))
+        logInfo(s"Model updated at time ${time.toString}")
+        val display = model.get.weights.size match {
+          case x if x > 100 => model.get.weights.toArray.take(100).mkString("[", ",", "...")
+          case _ => model.get.weights.toArray.mkString("[", ",", "]")
+        }
+        logInfo(s"Current model: weights, ${display}")
       }
-      logInfo("Current model: weights, %s".format (display))
     }
   }
 
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala
index e98b61e13e..fd653296c9 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala
@@ -158,4 +158,21 @@ class StreamingLogisticRegressionSuite extends SparkFunSuite with TestSuiteBase
     val error = output.map(batch => batch.map(p => math.abs(p._1 - p._2)).sum / nPoints).toList
     assert(error.head > 0.8 & error.last < 0.2)
   }
+
+  // Test empty RDDs in a stream
+  test("handling empty RDDs in a stream") {
+    val model = new StreamingLogisticRegressionWithSGD()
+      .setInitialWeights(Vectors.dense(-0.1))
+      .setStepSize(0.01)
+      .setNumIterations(10)
+    val numBatches = 10
+    val emptyInput = Seq.empty[Seq[LabeledPoint]]
+    val ssc = setupStreams(emptyInput,
+      (inputDStream: DStream[LabeledPoint]) => {
+        model.trainOn(inputDStream)
+        model.predictOnValues(inputDStream.map(x => (x.label, x.features)))
+      }
+    )
+    val output: Seq[Seq[(Double, Double)]] = runStreams(ssc, numBatches, numBatches)
+  }
 }
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala
index 9a379406d5..f5e2d31056 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala
@@ -166,4 +166,22 @@ class StreamingLinearRegressionSuite extends SparkFunSuite with TestSuiteBase {
     val error = output.map(batch => batch.map(p => math.abs(p._1 - p._2)).sum / nPoints).toList
     assert((error.head - error.last) > 2)
   }
+
+  // Test empty RDDs in a stream
+  test("handling empty RDDs in a stream") {
+    val model = new StreamingLinearRegressionWithSGD()
+      .setInitialWeights(Vectors.dense(0.0, 0.0))
+      .setStepSize(0.2)
+      .setNumIterations(25)
+    val numBatches = 10
+    val nPoints = 100
+    val emptyInput = Seq.empty[Seq[LabeledPoint]]
+    val ssc = setupStreams(emptyInput,
+      (inputDStream: DStream[LabeledPoint]) => {
+        model.trainOn(inputDStream)
+        model.predictOnValues(inputDStream.map(x => (x.label, x.features)))
+      }
+    )
+    val output: Seq[Seq[(Double, Double)]] = runStreams(ssc, numBatches, numBatches)
+  }
 }
-- 
GitLab