From d02d5b9295b169c3ebb0967453b2835edb8a121f Mon Sep 17 00:00:00 2001
From: "Joseph K. Bradley" <joseph@databricks.com>
Date: Wed, 18 Nov 2015 21:44:01 -0800
Subject: [PATCH] [SPARK-11842][ML] Small cleanups to existing Readers and
 Writers

Updates:
* Add repartition(1) to save() methods' saving of data for LogisticRegressionModel, LinearRegressionModel.
* Strengthen privacy to class and companion object for Writers and Readers
* Change LogisticRegressionSuite read/write test to fit intercept
* Add Since versions for read/write methods in Pipeline, LogisticRegression
* Switch from hand-written class names in Readers to using getClass

CC: mengxr

CC: yanboliang Would you mind taking a look at this PR?  mengxr might not be able to soon.  Thank you!

Author: Joseph K. Bradley <joseph@databricks.com>

Closes #9829 from jkbradley/ml-io-cleanups.
---
 .../scala/org/apache/spark/ml/Pipeline.scala  | 22 +++++++++++++------
 .../classification/LogisticRegression.scala   | 19 ++++++++++------
 .../spark/ml/feature/CountVectorizer.scala    |  2 +-
 .../org/apache/spark/ml/feature/IDF.scala     |  2 +-
 .../spark/ml/feature/MinMaxScaler.scala       |  2 +-
 .../spark/ml/feature/StandardScaler.scala     |  2 +-
 .../spark/ml/feature/StringIndexer.scala      |  2 +-
 .../apache/spark/ml/recommendation/ALS.scala  |  6 ++---
 .../ml/regression/LinearRegression.scala      |  4 ++--
 .../LogisticRegressionSuite.scala             |  2 +-
 10 files changed, 38 insertions(+), 25 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
index b0f22e042e..6f15b37abc 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
@@ -27,7 +27,7 @@ import org.json4s._
 import org.json4s.jackson.JsonMethods._
 
 import org.apache.spark.{SparkContext, Logging}
-import org.apache.spark.annotation.{DeveloperApi, Experimental}
+import org.apache.spark.annotation.{Since, DeveloperApi, Experimental}
 import org.apache.spark.ml.param.{Param, ParamMap, Params}
 import org.apache.spark.ml.util.MLReader
 import org.apache.spark.ml.util.MLWriter
@@ -174,16 +174,20 @@ class Pipeline(override val uid: String) extends Estimator[PipelineModel] with M
     theStages.foldLeft(schema)((cur, stage) => stage.transformSchema(cur))
   }
 
+  @Since("1.6.0")
   override def write: MLWriter = new Pipeline.PipelineWriter(this)
 }
 
+@Since("1.6.0")
 object Pipeline extends MLReadable[Pipeline] {
 
+  @Since("1.6.0")
   override def read: MLReader[Pipeline] = new PipelineReader
 
+  @Since("1.6.0")
   override def load(path: String): Pipeline = super.load(path)
 
-  private[ml] class PipelineWriter(instance: Pipeline) extends MLWriter {
+  private[Pipeline] class PipelineWriter(instance: Pipeline) extends MLWriter {
 
     SharedReadWrite.validateStages(instance.getStages)
 
@@ -191,10 +195,10 @@ object Pipeline extends MLReadable[Pipeline] {
       SharedReadWrite.saveImpl(instance, instance.getStages, sc, path)
   }
 
-  private[ml] class PipelineReader extends MLReader[Pipeline] {
+  private class PipelineReader extends MLReader[Pipeline] {
 
     /** Checked against metadata when loading model */
-    private val className = "org.apache.spark.ml.Pipeline"
+    private val className = classOf[Pipeline].getName
 
     override def load(path: String): Pipeline = {
       val (uid: String, stages: Array[PipelineStage]) = SharedReadWrite.load(className, sc, path)
@@ -333,18 +337,22 @@ class PipelineModel private[ml] (
     new PipelineModel(uid, stages.map(_.copy(extra))).setParent(parent)
   }
 
+  @Since("1.6.0")
   override def write: MLWriter = new PipelineModel.PipelineModelWriter(this)
 }
 
+@Since("1.6.0")
 object PipelineModel extends MLReadable[PipelineModel] {
 
   import Pipeline.SharedReadWrite
 
+  @Since("1.6.0")
   override def read: MLReader[PipelineModel] = new PipelineModelReader
 
+  @Since("1.6.0")
   override def load(path: String): PipelineModel = super.load(path)
 
-  private[ml] class PipelineModelWriter(instance: PipelineModel) extends MLWriter {
+  private[PipelineModel] class PipelineModelWriter(instance: PipelineModel) extends MLWriter {
 
     SharedReadWrite.validateStages(instance.stages.asInstanceOf[Array[PipelineStage]])
 
@@ -352,10 +360,10 @@ object PipelineModel extends MLReadable[PipelineModel] {
       instance.stages.asInstanceOf[Array[PipelineStage]], sc, path)
   }
 
-  private[ml] class PipelineModelReader extends MLReader[PipelineModel] {
+  private class PipelineModelReader extends MLReader[PipelineModel] {
 
     /** Checked against metadata when loading model */
-    private val className = "org.apache.spark.ml.PipelineModel"
+    private val className = classOf[PipelineModel].getName
 
     override def load(path: String): PipelineModel = {
       val (uid: String, stages: Array[PipelineStage]) = SharedReadWrite.load(className, sc, path)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index a3cc49f7f0..418bbdc9a0 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -24,7 +24,7 @@ import breeze.optimize.{CachedDiffFunction, DiffFunction, LBFGS => BreezeLBFGS,
 import org.apache.hadoop.fs.Path
 
 import org.apache.spark.{Logging, SparkException}
-import org.apache.spark.annotation.Experimental
+import org.apache.spark.annotation.{Since, Experimental}
 import org.apache.spark.ml.feature.Instance
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared._
@@ -525,18 +525,23 @@ class LogisticRegressionModel private[ml] (
    *
    * This also does not save the [[parent]] currently.
    */
+  @Since("1.6.0")
   override def write: MLWriter = new LogisticRegressionModel.LogisticRegressionModelWriter(this)
 }
 
 
+@Since("1.6.0")
 object LogisticRegressionModel extends MLReadable[LogisticRegressionModel] {
 
+  @Since("1.6.0")
   override def read: MLReader[LogisticRegressionModel] = new LogisticRegressionModelReader
 
+  @Since("1.6.0")
   override def load(path: String): LogisticRegressionModel = super.load(path)
 
   /** [[MLWriter]] instance for [[LogisticRegressionModel]] */
-  private[classification] class LogisticRegressionModelWriter(instance: LogisticRegressionModel)
+  private[LogisticRegressionModel]
+  class LogisticRegressionModelWriter(instance: LogisticRegressionModel)
     extends MLWriter with Logging {
 
     private case class Data(
@@ -552,15 +557,15 @@ object LogisticRegressionModel extends MLReadable[LogisticRegressionModel] {
       val data = Data(instance.numClasses, instance.numFeatures, instance.intercept,
         instance.coefficients)
       val dataPath = new Path(path, "data").toString
-      sqlContext.createDataFrame(Seq(data)).write.format("parquet").save(dataPath)
+      sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
     }
   }
 
-  private[classification] class LogisticRegressionModelReader
+  private class LogisticRegressionModelReader
     extends MLReader[LogisticRegressionModel] {
 
     /** Checked against metadata when loading model */
-    private val className = "org.apache.spark.ml.classification.LogisticRegressionModel"
+    private val className = classOf[LogisticRegressionModel].getName
 
     override def load(path: String): LogisticRegressionModel = {
       val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
@@ -603,7 +608,7 @@ private[classification] class MultiClassSummarizer extends Serializable {
    * @return This MultilabelSummarizer
    */
   def add(label: Double, weight: Double = 1.0): this.type = {
-    require(weight >= 0.0, s"instance weight, ${weight} has to be >= 0.0")
+    require(weight >= 0.0, s"instance weight, $weight has to be >= 0.0")
 
     if (weight == 0.0) return this
 
@@ -839,7 +844,7 @@ private class LogisticAggregator(
     instance match { case Instance(label, weight, features) =>
       require(dim == features.size, s"Dimensions mismatch when adding new instance." +
         s" Expecting $dim but got ${features.size}.")
-      require(weight >= 0.0, s"instance weight, ${weight} has to be >= 0.0")
+      require(weight >= 0.0, s"instance weight, $weight has to be >= 0.0")
 
       if (weight == 0.0) return this
 
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala
index 4969cf4245..b9e2144c0a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala
@@ -266,7 +266,7 @@ object CountVectorizerModel extends MLReadable[CountVectorizerModel] {
 
   private class CountVectorizerModelReader extends MLReader[CountVectorizerModel] {
 
-    private val className = "org.apache.spark.ml.feature.CountVectorizerModel"
+    private val className = classOf[CountVectorizerModel].getName
 
     override def load(path: String): CountVectorizerModel = {
       val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala
index 0e00ef6f2e..f7b0f29a27 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala
@@ -155,7 +155,7 @@ object IDFModel extends MLReadable[IDFModel] {
 
   private class IDFModelReader extends MLReader[IDFModel] {
 
-    private val className = "org.apache.spark.ml.feature.IDFModel"
+    private val className = classOf[IDFModel].getName
 
     override def load(path: String): IDFModel = {
       val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala
index ed24eabb50..c2866f5ece 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala
@@ -210,7 +210,7 @@ object MinMaxScalerModel extends MLReadable[MinMaxScalerModel] {
 
   private class MinMaxScalerModelReader extends MLReader[MinMaxScalerModel] {
 
-    private val className = "org.apache.spark.ml.feature.MinMaxScalerModel"
+    private val className = classOf[MinMaxScalerModel].getName
 
     override def load(path: String): MinMaxScalerModel = {
       val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
index 1f689c1da1..6d545219eb 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
@@ -180,7 +180,7 @@ object StandardScalerModel extends MLReadable[StandardScalerModel] {
 
   private class StandardScalerModelReader extends MLReader[StandardScalerModel] {
 
-    private val className = "org.apache.spark.ml.feature.StandardScalerModel"
+    private val className = classOf[StandardScalerModel].getName
 
     override def load(path: String): StandardScalerModel = {
       val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
index 97a2e4f6d6..5c40c35eea 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
@@ -210,7 +210,7 @@ object StringIndexerModel extends MLReadable[StringIndexerModel] {
 
   private class StringIndexerModelReader extends MLReader[StringIndexerModel] {
 
-    private val className = "org.apache.spark.ml.feature.StringIndexerModel"
+    private val className = classOf[StringIndexerModel].getName
 
     override def load(path: String): StringIndexerModel = {
       val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
index 795b73c4c2..4d35177ad9 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
@@ -237,7 +237,7 @@ object ALSModel extends MLReadable[ALSModel] {
   @Since("1.6.0")
   override def load(path: String): ALSModel = super.load(path)
 
-  private[recommendation] class ALSModelWriter(instance: ALSModel) extends MLWriter {
+  private[ALSModel] class ALSModelWriter(instance: ALSModel) extends MLWriter {
 
     override protected def saveImpl(path: String): Unit = {
       val extraMetadata = render("rank" -> instance.rank)
@@ -249,10 +249,10 @@ object ALSModel extends MLReadable[ALSModel] {
     }
   }
 
-  private[recommendation] class ALSModelReader extends MLReader[ALSModel] {
+  private class ALSModelReader extends MLReader[ALSModel] {
 
     /** Checked against metadata when loading model */
-    private val className = "org.apache.spark.ml.recommendation.ALSModel"
+    private val className = classOf[ALSModel].getName
 
     override def load(path: String): ALSModel = {
       val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
index 7ba1a60eda..70ccec766c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
@@ -467,14 +467,14 @@ object LinearRegressionModel extends MLReadable[LinearRegressionModel] {
       // Save model data: intercept, coefficients
       val data = Data(instance.intercept, instance.coefficients)
       val dataPath = new Path(path, "data").toString
-      sqlContext.createDataFrame(Seq(data)).write.format("parquet").save(dataPath)
+      sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
     }
   }
 
   private class LinearRegressionModelReader extends MLReader[LinearRegressionModel] {
 
     /** Checked against metadata when loading model */
-    private val className = "org.apache.spark.ml.regression.LinearRegressionModel"
+    private val className = classOf[LinearRegressionModel].getName
 
     override def load(path: String): LinearRegressionModel = {
       val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
index 48ce1bb630..a9a6ff8a78 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
@@ -898,7 +898,7 @@ object LogisticRegressionSuite {
     "regParam" -> 0.01,
     "elasticNetParam" -> 0.1,
     "maxIter" -> 2,  // intentionally small
-    "fitIntercept" -> false,
+    "fitIntercept" -> true,
     "tol" -> 0.8,
     "standardization" -> false,
     "threshold" -> 0.6
-- 
GitLab