diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala index aee51bf22d8d09731c308c9fc21a8703363d2304..141052ba813ee86b29152b7549bee7b70ea67088 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala @@ -83,13 +83,15 @@ abstract class StreamingLinearAlgorithm[ throw new IllegalArgumentException("Model must be initialized before starting training.") } data.foreachRDD { (rdd, time) => - model = Some(algorithm.run(rdd, model.get.weights)) - logInfo("Model updated at time %s".format(time.toString)) - val display = model.get.weights.size match { - case x if x > 100 => model.get.weights.toArray.take(100).mkString("[", ",", "...") - case _ => model.get.weights.toArray.mkString("[", ",", "]") + if (!rdd.isEmpty) { + model = Some(algorithm.run(rdd, model.get.weights)) + logInfo(s"Model updated at time ${time.toString}") + val display = model.get.weights.size match { + case x if x > 100 => model.get.weights.toArray.take(100).mkString("[", ",", "...") + case _ => model.get.weights.toArray.mkString("[", ",", "]") + } + logInfo(s"Current model: weights, ${display}") } - logInfo("Current model: weights, %s".format (display)) } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala index e98b61e13e21f699fac5ce157db3b2c78e18afda..fd653296c9d9735c1da1ba5c893029af44249fe1 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala @@ -158,4 +158,21 @@ class StreamingLogisticRegressionSuite extends SparkFunSuite with TestSuiteBase val error = output.map(batch => batch.map(p => math.abs(p._1 - p._2)).sum / nPoints).toList assert(error.head > 0.8 & error.last < 0.2) } + + // Test empty RDDs in a stream + test("handling empty RDDs in a stream") { + val model = new StreamingLogisticRegressionWithSGD() + .setInitialWeights(Vectors.dense(-0.1)) + .setStepSize(0.01) + .setNumIterations(10) + val numBatches = 10 + val emptyInput = Seq.empty[Seq[LabeledPoint]] + val ssc = setupStreams(emptyInput, + (inputDStream: DStream[LabeledPoint]) => { + model.trainOn(inputDStream) + model.predictOnValues(inputDStream.map(x => (x.label, x.features))) + } + ) + val output: Seq[Seq[(Double, Double)]] = runStreams(ssc, numBatches, numBatches) + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala index 9a379406d5061fbbd873f41d7d8e576e942fb903..f5e2d31056cbd7542fdc1fafdca96ab322f3e4cc 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala @@ -166,4 +166,22 @@ class StreamingLinearRegressionSuite extends SparkFunSuite with TestSuiteBase { val error = output.map(batch => batch.map(p => math.abs(p._1 - p._2)).sum / nPoints).toList assert((error.head - error.last) > 2) } + + // Test empty RDDs in a stream + test("handling empty RDDs in a stream") { + val model = new StreamingLinearRegressionWithSGD() + .setInitialWeights(Vectors.dense(0.0, 0.0)) + .setStepSize(0.2) + .setNumIterations(25) + val numBatches = 10 + val nPoints = 100 + val emptyInput = Seq.empty[Seq[LabeledPoint]] + val ssc = setupStreams(emptyInput, + (inputDStream: DStream[LabeledPoint]) => { + model.trainOn(inputDStream) + model.predictOnValues(inputDStream.map(x => (x.label, x.features))) + } + ) + val output: Seq[Seq[(Double, Double)]] = runStreams(ssc, numBatches, numBatches) + } }