From 0dd06485c4222a896c0d1ee6a04d30043de3626c Mon Sep 17 00:00:00 2001 From: Yanbo Liang <ybliang8@gmail.com> Date: Wed, 9 Mar 2016 11:59:22 -0800 Subject: [PATCH] [SPARK-13615][ML] GeneralizedLinearRegression supports save/load ## What changes were proposed in this pull request? ```GeneralizedLinearRegression``` supports ```save/load```. cc mengxr ## How was this patch tested? unit test. Author: Yanbo Liang <ybliang8@gmail.com> Closes #11465 from yanboliang/spark-13615. --- .../GeneralizedLinearRegression.scala | 74 +++++++++++++++++-- .../GeneralizedLinearRegressionSuite.scala | 32 +++++++- 2 files changed, 96 insertions(+), 10 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala index a850dfee0a..de1dff9421 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala @@ -18,6 +18,7 @@ package org.apache.spark.ml.regression import breeze.stats.distributions.{Gaussian => GD} +import org.apache.hadoop.fs.Path import org.apache.spark.{Logging, SparkException} import org.apache.spark.annotation.{Experimental, Since} @@ -26,7 +27,7 @@ import org.apache.spark.ml.feature.Instance import org.apache.spark.ml.optim._ import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ -import org.apache.spark.ml.util.Identifiable +import org.apache.spark.ml.util._ import org.apache.spark.mllib.linalg.{BLAS, Vector} import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Row} @@ -106,7 +107,7 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam @Since("2.0.0") class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val uid: String) extends Regressor[Vector, GeneralizedLinearRegression, GeneralizedLinearRegressionModel] - with GeneralizedLinearRegressionBase with Logging { + with GeneralizedLinearRegressionBase with DefaultParamsWritable with Logging { import GeneralizedLinearRegression._ @@ -236,10 +237,13 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val } @Since("2.0.0") -private[ml] object GeneralizedLinearRegression { +object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLinearRegression] { + + @Since("2.0.0") + override def load(path: String): GeneralizedLinearRegression = super.load(path) /** Set of family and link pairs that GeneralizedLinearRegression supports. */ - lazy val supportedFamilyAndLinkPairs = Set( + private[ml] lazy val supportedFamilyAndLinkPairs = Set( Gaussian -> Identity, Gaussian -> Log, Gaussian -> Inverse, Binomial -> Logit, Binomial -> Probit, Binomial -> CLogLog, Poisson -> Log, Poisson -> Identity, Poisson -> Sqrt, @@ -247,12 +251,12 @@ private[ml] object GeneralizedLinearRegression { ) /** Set of family names that GeneralizedLinearRegression supports. */ - lazy val supportedFamilyNames = supportedFamilyAndLinkPairs.map(_._1.name) + private[ml] lazy val supportedFamilyNames = supportedFamilyAndLinkPairs.map(_._1.name) /** Set of link names that GeneralizedLinearRegression supports. */ - lazy val supportedLinkNames = supportedFamilyAndLinkPairs.map(_._2.name) + private[ml] lazy val supportedLinkNames = supportedFamilyAndLinkPairs.map(_._2.name) - val epsilon: Double = 1E-16 + private[ml] val epsilon: Double = 1E-16 /** * Wrapper of family and link combination used in the model. @@ -552,7 +556,7 @@ class GeneralizedLinearRegressionModel private[ml] ( @Since("2.0.0") val coefficients: Vector, @Since("2.0.0") val intercept: Double) extends RegressionModel[Vector, GeneralizedLinearRegressionModel] - with GeneralizedLinearRegressionBase { + with GeneralizedLinearRegressionBase with MLWritable { import GeneralizedLinearRegression._ @@ -574,4 +578,58 @@ class GeneralizedLinearRegressionModel private[ml] ( copyValues(new GeneralizedLinearRegressionModel(uid, coefficients, intercept), extra) .setParent(parent) } + + @Since("2.0.0") + override def write: MLWriter = + new GeneralizedLinearRegressionModel.GeneralizedLinearRegressionModelWriter(this) +} + +@Since("2.0.0") +object GeneralizedLinearRegressionModel extends MLReadable[GeneralizedLinearRegressionModel] { + + @Since("2.0.0") + override def read: MLReader[GeneralizedLinearRegressionModel] = + new GeneralizedLinearRegressionModelReader + + @Since("2.0.0") + override def load(path: String): GeneralizedLinearRegressionModel = super.load(path) + + /** [[MLWriter]] instance for [[GeneralizedLinearRegressionModel]] */ + private[GeneralizedLinearRegressionModel] + class GeneralizedLinearRegressionModelWriter(instance: GeneralizedLinearRegressionModel) + extends MLWriter with Logging { + + private case class Data(intercept: Double, coefficients: Vector) + + override protected def saveImpl(path: String): Unit = { + // Save metadata and Params + DefaultParamsWriter.saveMetadata(instance, path, sc) + // Save model data: intercept, coefficients + val data = Data(instance.intercept, instance.coefficients) + val dataPath = new Path(path, "data").toString + sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + } + } + + private class GeneralizedLinearRegressionModelReader + extends MLReader[GeneralizedLinearRegressionModel] { + + /** Checked against metadata when loading model */ + private val className = classOf[GeneralizedLinearRegressionModel].getName + + override def load(path: String): GeneralizedLinearRegressionModel = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + + val dataPath = new Path(path, "data").toString + val data = sqlContext.read.parquet(dataPath) + .select("intercept", "coefficients").head() + val intercept = data.getDouble(0) + val coefficients = data.getAs[Vector](1) + + val model = new GeneralizedLinearRegressionModel(metadata.uid, coefficients, intercept) + + DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala index 8bfa9855ce..618304ad19 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala @@ -21,7 +21,7 @@ import scala.util.Random import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.MLTestingUtils +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.mllib.classification.LogisticRegressionSuite._ import org.apache.spark.mllib.linalg.{BLAS, DenseVector, Vectors} import org.apache.spark.mllib.random._ @@ -30,7 +30,8 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.sql.{DataFrame, Row} -class GeneralizedLinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { +class GeneralizedLinearRegressionSuite + extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { private val seed: Int = 42 @transient var datasetGaussianIdentity: DataFrame = _ @@ -464,10 +465,37 @@ class GeneralizedLinearRegressionSuite extends SparkFunSuite with MLlibTestSpark } } } + + test("read/write") { + def checkModelData( + model: GeneralizedLinearRegressionModel, + model2: GeneralizedLinearRegressionModel): Unit = { + assert(model.intercept === model2.intercept) + assert(model.coefficients.toArray === model2.coefficients.toArray) + } + + val glr = new GeneralizedLinearRegression() + testEstimatorAndModelReadWrite(glr, datasetPoissonLog, + GeneralizedLinearRegressionSuite.allParamSettings, checkModelData) + } } object GeneralizedLinearRegressionSuite { + /** + * Mapping from all Params to valid settings which differ from the defaults. + * This is useful for tests which need to exercise all Params, such as save/load. + * This excludes input columns to simplify some tests. + */ + val allParamSettings: Map[String, Any] = Map( + "family" -> "poisson", + "link" -> "log", + "fitIntercept" -> true, + "maxIter" -> 2, // intentionally small + "tol" -> 0.8, + "regParam" -> 0.01, + "predictionCol" -> "myPrediction") + def generateGeneralizedLinearRegressionInput( intercept: Double, coefficients: Array[Double], -- GitLab