From d02d5b9295b169c3ebb0967453b2835edb8a121f Mon Sep 17 00:00:00 2001
From: "Joseph K. Bradley" <joseph@databricks.com>
Date: Wed, 18 Nov 2015 21:44:01 -0800
Subject: [PATCH] [SPARK-11842][ML] Small cleanups to existing Readers and
 Writers

Updates:
* Add repartition(1) to save() methods' saving of data for LogisticRegressionModel, LinearRegressionModel.
* Strengthen privacy to class and companion object for Writers and Readers
* Change LogisticRegressionSuite read/write test to fit intercept
* Add Since versions for read/write methods in Pipeline, LogisticRegression
* Switch from hand-written class names in Readers to using getClass

CC: mengxr
CC: yanboliang

Would you mind taking a look at this PR? mengxr might not be able to soon. Thank you!

Author: Joseph K. Bradley <joseph@databricks.com>

Closes #9829 from jkbradley/ml-io-cleanups.
---
 .../scala/org/apache/spark/ml/Pipeline.scala  | 22 +++++++++++++------
 .../classification/LogisticRegression.scala   | 19 ++++++++++------
 .../spark/ml/feature/CountVectorizer.scala    |  2 +-
 .../org/apache/spark/ml/feature/IDF.scala     |  2 +-
 .../spark/ml/feature/MinMaxScaler.scala       |  2 +-
 .../spark/ml/feature/StandardScaler.scala     |  2 +-
 .../spark/ml/feature/StringIndexer.scala      |  2 +-
 .../apache/spark/ml/recommendation/ALS.scala  |  6 ++---
 .../ml/regression/LinearRegression.scala      |  4 ++--
 .../LogisticRegressionSuite.scala             |  2 +-
 10 files changed, 38 insertions(+), 25 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
index b0f22e042e..6f15b37abc 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
@@ -27,7 +27,7 @@ import org.json4s._
 import org.json4s.jackson.JsonMethods._
 
 import org.apache.spark.{SparkContext, Logging}
-import org.apache.spark.annotation.{DeveloperApi, Experimental}
+import org.apache.spark.annotation.{Since, DeveloperApi, Experimental}
 import org.apache.spark.ml.param.{Param, ParamMap, Params}
 import org.apache.spark.ml.util.MLReader
 import org.apache.spark.ml.util.MLWriter
@@ -174,16 +174,20 @@ class Pipeline(override val uid: String) extends Estimator[PipelineModel] with M
     theStages.foldLeft(schema)((cur, stage) => stage.transformSchema(cur))
   }
 
+  @Since("1.6.0")
   override def write: MLWriter = new Pipeline.PipelineWriter(this)
 }
 
+@Since("1.6.0")
 object Pipeline extends MLReadable[Pipeline] {
 
+  @Since("1.6.0")
   override def read: MLReader[Pipeline] = new PipelineReader
 
+  @Since("1.6.0")
   override def load(path: String): Pipeline = super.load(path)
 
-  private[ml] class PipelineWriter(instance: Pipeline) extends MLWriter {
+  private[Pipeline] class PipelineWriter(instance: Pipeline) extends MLWriter {
 
     SharedReadWrite.validateStages(instance.getStages)
 
@@ -191,10 +195,10 @@ object Pipeline extends MLReadable[Pipeline] {
       SharedReadWrite.saveImpl(instance, instance.getStages, sc, path)
   }
 
-  private[ml] class PipelineReader extends MLReader[Pipeline] {
+  private class PipelineReader extends MLReader[Pipeline] {
 
     /** Checked against metadata when loading model */
-    private val className = "org.apache.spark.ml.Pipeline"
+    private val className = classOf[Pipeline].getName
 
     override def load(path: String): Pipeline = {
       val (uid: String, stages: Array[PipelineStage]) = SharedReadWrite.load(className, sc, path)
@@ -333,18 +337,22 @@ class PipelineModel private[ml] (
     new PipelineModel(uid, stages.map(_.copy(extra))).setParent(parent)
   }
 
+  @Since("1.6.0")
   override def write: MLWriter = new PipelineModel.PipelineModelWriter(this)
 }
 
+@Since("1.6.0")
 object PipelineModel extends MLReadable[PipelineModel] {
 
   import Pipeline.SharedReadWrite
 
+  @Since("1.6.0")
   override def read: MLReader[PipelineModel] = new PipelineModelReader
 
+  @Since("1.6.0")
   override def load(path: String): PipelineModel = super.load(path)
 
-  private[ml] class PipelineModelWriter(instance: PipelineModel) extends MLWriter {
+  private[PipelineModel] class PipelineModelWriter(instance: PipelineModel) extends MLWriter {
 
     SharedReadWrite.validateStages(instance.stages.asInstanceOf[Array[PipelineStage]])
 
@@ -352,10 +360,10 @@ object PipelineModel extends MLReadable[PipelineModel] {
       instance.stages.asInstanceOf[Array[PipelineStage]], sc, path)
   }
 
-  private[ml] class PipelineModelReader extends MLReader[PipelineModel] {
+  private class PipelineModelReader extends MLReader[PipelineModel] {
 
     /** Checked against metadata when loading model */
-    private val className = "org.apache.spark.ml.PipelineModel"
+    private val className = classOf[PipelineModel].getName
 
     override def load(path: String): PipelineModel = {
       val (uid: String, stages: Array[PipelineStage]) = SharedReadWrite.load(className, sc, path)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index a3cc49f7f0..418bbdc9a0 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -24,7 +24,7 @@ import breeze.optimize.{CachedDiffFunction, DiffFunction, LBFGS => BreezeLBFGS,
 import org.apache.hadoop.fs.Path
 
 import org.apache.spark.{Logging, SparkException}
-import org.apache.spark.annotation.Experimental
+import org.apache.spark.annotation.{Since, Experimental}
 import org.apache.spark.ml.feature.Instance
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared._
@@ -525,18 +525,23 @@ class LogisticRegressionModel private[ml] (
    *
   * This also does not save the [[parent]] currently.
   */
+  @Since("1.6.0")
   override def write: MLWriter =
     new LogisticRegressionModel.LogisticRegressionModelWriter(this)
 }
 
+@Since("1.6.0")
 object LogisticRegressionModel extends MLReadable[LogisticRegressionModel] {
 
+  @Since("1.6.0")
   override def read: MLReader[LogisticRegressionModel] = new LogisticRegressionModelReader
 
+  @Since("1.6.0")
   override def load(path: String): LogisticRegressionModel = super.load(path)
 
   /** [[MLWriter]] instance for [[LogisticRegressionModel]] */
-  private[classification] class LogisticRegressionModelWriter(instance: LogisticRegressionModel)
+  private[LogisticRegressionModel]
+  class LogisticRegressionModelWriter(instance: LogisticRegressionModel)
     extends MLWriter with Logging {
 
     private case class Data(
@@ -552,15 +557,15 @@ object LogisticRegressionModel extends MLReadable[LogisticRegressionModel] {
       val data = Data(instance.numClasses, instance.numFeatures, instance.intercept,
         instance.coefficients)
       val dataPath = new Path(path, "data").toString
-      sqlContext.createDataFrame(Seq(data)).write.format("parquet").save(dataPath)
+      sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
     }
   }
 
-  private[classification] class LogisticRegressionModelReader
+  private class LogisticRegressionModelReader
     extends MLReader[LogisticRegressionModel] {
 
     /** Checked against metadata when loading model */
-    private val className = "org.apache.spark.ml.classification.LogisticRegressionModel"
+    private val className = classOf[LogisticRegressionModel].getName
 
     override def load(path: String): LogisticRegressionModel = {
       val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
@@ -603,7 +608,7 @@ private[classification] class MultiClassSummarizer extends Serializable {
    * @return This MultilabelSummarizer
    */
   def add(label: Double, weight: Double = 1.0): this.type = {
-    require(weight >= 0.0, s"instance weight, ${weight} has to be >= 0.0")
+    require(weight >= 0.0, s"instance weight, $weight has to be >= 0.0")
 
     if (weight == 0.0) return this
 
@@ -839,7 +844,7 @@ private class LogisticAggregator(
     instance match { case Instance(label, weight, features) =>
       require(dim == features.size, s"Dimensions mismatch when adding new instance." +
         s" Expecting $dim but got ${features.size}.")
-      require(weight >= 0.0, s"instance weight, ${weight} has to be >= 0.0")
+      require(weight >= 0.0, s"instance weight, $weight has to be >= 0.0")
 
       if (weight == 0.0) return this
 
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala
index 4969cf4245..b9e2144c0a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala
@@ -266,7 +266,7 @@ object CountVectorizerModel extends MLReadable[CountVectorizerModel] {
 
   private class CountVectorizerModelReader extends MLReader[CountVectorizerModel] {
 
-    private val className = "org.apache.spark.ml.feature.CountVectorizerModel"
+    private val className = classOf[CountVectorizerModel].getName
 
     override def load(path: String): CountVectorizerModel = {
       val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala
index 0e00ef6f2e..f7b0f29a27 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala
@@ -155,7 +155,7 @@ object IDFModel extends MLReadable[IDFModel] {
 
   private class IDFModelReader extends MLReader[IDFModel] {
 
-    private val className = "org.apache.spark.ml.feature.IDFModel"
+    private val className = classOf[IDFModel].getName
 
     override def load(path: String): IDFModel = {
       val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala
index ed24eabb50..c2866f5ece 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala
@@ -210,7 +210,7 @@ object MinMaxScalerModel extends MLReadable[MinMaxScalerModel] {
 
   private class MinMaxScalerModelReader extends MLReader[MinMaxScalerModel] {
 
-    private val className = "org.apache.spark.ml.feature.MinMaxScalerModel"
+    private val className = classOf[MinMaxScalerModel].getName
 
     override def load(path: String): MinMaxScalerModel = {
       val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
index 1f689c1da1..6d545219eb 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
@@ -180,7 +180,7 @@ object StandardScalerModel extends MLReadable[StandardScalerModel] {
 
   private class StandardScalerModelReader extends MLReader[StandardScalerModel] {
 
-    private val className = "org.apache.spark.ml.feature.StandardScalerModel"
+    private val className = classOf[StandardScalerModel].getName
 
     override def load(path: String): StandardScalerModel = {
       val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
index 97a2e4f6d6..5c40c35eea 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
@@ -210,7 +210,7 @@ object StringIndexerModel extends MLReadable[StringIndexerModel] {
 
   private class StringIndexerModelReader extends MLReader[StringIndexerModel] {
 
-    private val className = "org.apache.spark.ml.feature.StringIndexerModel"
+    private val className = classOf[StringIndexerModel].getName
 
     override def load(path: String): StringIndexerModel = {
       val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
index 795b73c4c2..4d35177ad9 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
@@ -237,7 +237,7 @@ object ALSModel extends MLReadable[ALSModel] {
   @Since("1.6.0")
   override def load(path: String): ALSModel = super.load(path)
 
-  private[recommendation] class ALSModelWriter(instance: ALSModel) extends MLWriter {
+  private[ALSModel] class ALSModelWriter(instance: ALSModel) extends MLWriter {
 
     override protected def saveImpl(path: String): Unit = {
       val extraMetadata = render("rank" -> instance.rank)
@@ -249,10 +249,10 @@ object ALSModel extends MLReadable[ALSModel] {
     }
   }
 
-  private[recommendation] class ALSModelReader extends MLReader[ALSModel] {
+  private class ALSModelReader extends MLReader[ALSModel] {
 
     /** Checked against metadata when loading model */
-    private val className = "org.apache.spark.ml.recommendation.ALSModel"
+    private val className = classOf[ALSModel].getName
 
     override def load(path: String): ALSModel = {
       val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
index 7ba1a60eda..70ccec766c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
@@ -467,14 +467,14 @@ object LinearRegressionModel extends MLReadable[LinearRegressionModel] {
       // Save model data: intercept, coefficients
       val data = Data(instance.intercept, instance.coefficients)
       val dataPath = new Path(path, "data").toString
-      sqlContext.createDataFrame(Seq(data)).write.format("parquet").save(dataPath)
+      sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
     }
   }
 
   private class LinearRegressionModelReader extends MLReader[LinearRegressionModel] {
 
     /** Checked against metadata when loading model */
-    private val className = "org.apache.spark.ml.regression.LinearRegressionModel"
+    private val className = classOf[LinearRegressionModel].getName
 
     override def load(path: String): LinearRegressionModel = {
       val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
index 48ce1bb630..a9a6ff8a78 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
@@ -898,7 +898,7 @@ object LogisticRegressionSuite {
     "regParam" -> 0.01,
     "elasticNetParam" -> 0.1,
     "maxIter" -> 2,  // intentionally small
-    "fitIntercept" -> false,
+    "fitIntercept" -> true,
     "tol" -> 0.8,
     "standardization" -> false,
     "threshold" -> 0.6
-- 
GitLab
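
Illustration (not part of the patch): the two idioms this change standardizes on can be shown outside the MLWriter/MLReader plumbing. The sketch below is a hypothetical, self-contained Spark 1.6-era example; the ReadWriteIdioms object, ModelData case class, local[2] master, and /tmp output path are invented for illustration, while the classOf[...].getName and repartition(1).write.parquet(...) calls mirror what the patch applies inside LogisticRegressionModel and LinearRegressionModel.

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext

// Hypothetical payload standing in for a model's saved data; not a Spark class.
case class ModelData(intercept: Double, coefficient: Double)

object ReadWriteIdioms {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("rw-idioms").setMaster("local[2]"))
    val sqlContext = new SQLContext(sc)

    // Idiom 1: derive the expected class name from the class itself instead of a
    // hand-written "org.apache.spark.ml..." string, so the value checked against
    // saved metadata cannot silently diverge after a rename.
    val expectedClassName = classOf[ModelData].getName

    // Idiom 2: model data is tiny, so repartition(1) before writing Parquet to get
    // a single data file instead of one file per default partition.
    val dataPath = "/tmp/rw-idioms/data"  // hypothetical output path
    sqlContext.createDataFrame(Seq(ModelData(0.5, 1.25)))
      .repartition(1)
      .write.mode("overwrite").parquet(dataPath)

    // Reading back mirrors what the model readers do after loading metadata.
    val row = sqlContext.read.parquet(dataPath).select("intercept", "coefficient").head()
    println(s"$expectedClassName -> intercept = ${row.getDouble(0)}")

    sc.stop()
  }
}

The repartition(1) matters because createDataFrame(Seq(...)) otherwise spreads a handful of values across the default number of partitions, yielding many near-empty Parquet part files for tiny model data; deriving className from the class keeps the metadata check in sync if the class is ever renamed or repackaged.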