Skip to content
Snippets Groups Projects
Commit 2cf46d5a authored by Xusen Yin's avatar Xusen Yin Committed by Xiangrui Meng
Browse files

[SPARK-11871] Add save/load for MLPC

## What changes were proposed in this pull request?

https://issues.apache.org/jira/browse/SPARK-11871

Add save/load for MLPC

## How was this patch tested?

Test with Scala unit test

Author: Xusen Yin <yinxusen@gmail.com>

Closes #9854 from yinxusen/SPARK-11871.
parent d283223a
No related branches found
No related tags found
No related merge requests found
...@@ -19,12 +19,14 @@ package org.apache.spark.ml.classification ...@@ -19,12 +19,14 @@ package org.apache.spark.ml.classification
import scala.collection.JavaConverters._ import scala.collection.JavaConverters._
import org.apache.hadoop.fs.Path
import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml.{PredictionModel, Predictor, PredictorParams} import org.apache.spark.ml.{PredictionModel, Predictor, PredictorParams}
import org.apache.spark.ml.ann.{FeedForwardTopology, FeedForwardTrainer} import org.apache.spark.ml.ann.{FeedForwardTopology, FeedForwardTrainer}
import org.apache.spark.ml.param.{IntArrayParam, IntParam, ParamMap, ParamValidators} import org.apache.spark.ml.param.{IntArrayParam, IntParam, ParamMap, ParamValidators}
import org.apache.spark.ml.param.shared.{HasMaxIter, HasSeed, HasTol} import org.apache.spark.ml.param.shared.{HasMaxIter, HasSeed, HasTol}
import org.apache.spark.ml.util.Identifiable import org.apache.spark.ml.util._
import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.sql.DataFrame import org.apache.spark.sql.DataFrame
...@@ -110,7 +112,7 @@ private object LabelConverter { ...@@ -110,7 +112,7 @@ private object LabelConverter {
class MultilayerPerceptronClassifier @Since("1.5.0") ( class MultilayerPerceptronClassifier @Since("1.5.0") (
@Since("1.5.0") override val uid: String) @Since("1.5.0") override val uid: String)
extends Predictor[Vector, MultilayerPerceptronClassifier, MultilayerPerceptronClassificationModel] extends Predictor[Vector, MultilayerPerceptronClassifier, MultilayerPerceptronClassificationModel]
with MultilayerPerceptronParams { with MultilayerPerceptronParams with DefaultParamsWritable {
@Since("1.5.0") @Since("1.5.0")
def this() = this(Identifiable.randomUID("mlpc")) def this() = this(Identifiable.randomUID("mlpc"))
...@@ -172,6 +174,14 @@ class MultilayerPerceptronClassifier @Since("1.5.0") ( ...@@ -172,6 +174,14 @@ class MultilayerPerceptronClassifier @Since("1.5.0") (
} }
} }
@Since("2.0.0")
object MultilayerPerceptronClassifier
extends DefaultParamsReadable[MultilayerPerceptronClassifier] {
@Since("2.0.0")
override def load(path: String): MultilayerPerceptronClassifier = super.load(path)
}
/** /**
* :: Experimental :: * :: Experimental ::
* Classification model based on the Multilayer Perceptron. * Classification model based on the Multilayer Perceptron.
...@@ -188,7 +198,7 @@ class MultilayerPerceptronClassificationModel private[ml] ( ...@@ -188,7 +198,7 @@ class MultilayerPerceptronClassificationModel private[ml] (
@Since("1.5.0") val layers: Array[Int], @Since("1.5.0") val layers: Array[Int],
@Since("1.5.0") val weights: Vector) @Since("1.5.0") val weights: Vector)
extends PredictionModel[Vector, MultilayerPerceptronClassificationModel] extends PredictionModel[Vector, MultilayerPerceptronClassificationModel]
with Serializable { with Serializable with MLWritable {
@Since("1.6.0") @Since("1.6.0")
override val numFeatures: Int = layers.head override val numFeatures: Int = layers.head
...@@ -214,4 +224,57 @@ class MultilayerPerceptronClassificationModel private[ml] ( ...@@ -214,4 +224,57 @@ class MultilayerPerceptronClassificationModel private[ml] (
override def copy(extra: ParamMap): MultilayerPerceptronClassificationModel = { override def copy(extra: ParamMap): MultilayerPerceptronClassificationModel = {
copyValues(new MultilayerPerceptronClassificationModel(uid, layers, weights), extra) copyValues(new MultilayerPerceptronClassificationModel(uid, layers, weights), extra)
} }
@Since("2.0.0")
override def write: MLWriter =
new MultilayerPerceptronClassificationModel.MultilayerPerceptronClassificationModelWriter(this)
}
@Since("2.0.0")
object MultilayerPerceptronClassificationModel
extends MLReadable[MultilayerPerceptronClassificationModel] {
@Since("2.0.0")
override def read: MLReader[MultilayerPerceptronClassificationModel] =
new MultilayerPerceptronClassificationModelReader
@Since("2.0.0")
override def load(path: String): MultilayerPerceptronClassificationModel = super.load(path)
/** [[MLWriter]] instance for [[MultilayerPerceptronClassificationModel]] */
private[MultilayerPerceptronClassificationModel]
class MultilayerPerceptronClassificationModelWriter(
instance: MultilayerPerceptronClassificationModel) extends MLWriter {
private case class Data(layers: Array[Int], weights: Vector)
override protected def saveImpl(path: String): Unit = {
// Save metadata and Params
DefaultParamsWriter.saveMetadata(instance, path, sc)
// Save model data: layers, weights
val data = Data(instance.layers, instance.weights)
val dataPath = new Path(path, "data").toString
sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
}
}
private class MultilayerPerceptronClassificationModelReader
extends MLReader[MultilayerPerceptronClassificationModel] {
/** Checked against metadata when loading model */
private val className = classOf[MultilayerPerceptronClassificationModel].getName
override def load(path: String): MultilayerPerceptronClassificationModel = {
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
val dataPath = new Path(path, "data").toString
val data = sqlContext.read.parquet(dataPath).select("layers", "weights").head()
val layers = data.getAs[Seq[Int]](0).toArray
val weights = data.getAs[Vector](1)
val model = new MultilayerPerceptronClassificationModel(metadata.uid, layers, weights)
DefaultParamsReader.getAndSetParams(model, metadata)
model
}
}
} }
...@@ -18,31 +18,40 @@ ...@@ -18,31 +18,40 @@
package org.apache.spark.ml.classification package org.apache.spark.ml.classification
import org.apache.spark.SparkFunSuite import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.util.DefaultReadWriteTest
import org.apache.spark.mllib.classification.LogisticRegressionSuite._ import org.apache.spark.mllib.classification.LogisticRegressionSuite._
import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
import org.apache.spark.mllib.evaluation.MulticlassMetrics import org.apache.spark.mllib.evaluation.MulticlassMetrics
import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.sql.Row import org.apache.spark.sql.{DataFrame, Row}
class MultilayerPerceptronClassifierSuite extends SparkFunSuite with MLlibTestSparkContext { class MultilayerPerceptronClassifierSuite
extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
test("XOR function learning as binary classification problem with two outputs.") { @transient var dataset: DataFrame = _
val dataFrame = sqlContext.createDataFrame(Seq(
override def beforeAll(): Unit = {
super.beforeAll()
dataset = sqlContext.createDataFrame(Seq(
(Vectors.dense(0.0, 0.0), 0.0), (Vectors.dense(0.0, 0.0), 0.0),
(Vectors.dense(0.0, 1.0), 1.0), (Vectors.dense(0.0, 1.0), 1.0),
(Vectors.dense(1.0, 0.0), 1.0), (Vectors.dense(1.0, 0.0), 1.0),
(Vectors.dense(1.0, 1.0), 0.0)) (Vectors.dense(1.0, 1.0), 0.0))
).toDF("features", "label") ).toDF("features", "label")
}
test("XOR function learning as binary classification problem with two outputs.") {
val layers = Array[Int](2, 5, 2) val layers = Array[Int](2, 5, 2)
val trainer = new MultilayerPerceptronClassifier() val trainer = new MultilayerPerceptronClassifier()
.setLayers(layers) .setLayers(layers)
.setBlockSize(1) .setBlockSize(1)
.setSeed(11L) .setSeed(11L)
.setMaxIter(100) .setMaxIter(100)
val model = trainer.fit(dataFrame) val model = trainer.fit(dataset)
val result = model.transform(dataFrame) val result = model.transform(dataset)
val predictionAndLabels = result.select("prediction", "label").collect() val predictionAndLabels = result.select("prediction", "label").collect()
predictionAndLabels.foreach { case Row(p: Double, l: Double) => predictionAndLabels.foreach { case Row(p: Double, l: Double) =>
assert(p == l) assert(p == l)
...@@ -92,4 +101,26 @@ class MultilayerPerceptronClassifierSuite extends SparkFunSuite with MLlibTestSp ...@@ -92,4 +101,26 @@ class MultilayerPerceptronClassifierSuite extends SparkFunSuite with MLlibTestSp
val mlpMetrics = new MulticlassMetrics(mlpPredictionAndLabels) val mlpMetrics = new MulticlassMetrics(mlpPredictionAndLabels)
assert(mlpMetrics.confusionMatrix ~== lrMetrics.confusionMatrix absTol 100) assert(mlpMetrics.confusionMatrix ~== lrMetrics.confusionMatrix absTol 100)
} }
test("read/write: MultilayerPerceptronClassifier") {
val mlp = new MultilayerPerceptronClassifier()
.setLayers(Array(2, 3, 2))
.setMaxIter(5)
.setBlockSize(2)
.setSeed(42)
.setTol(0.1)
.setFeaturesCol("myFeatures")
.setLabelCol("myLabel")
.setPredictionCol("myPrediction")
testDefaultReadWrite(mlp, testParams = true)
}
test("read/write: MultilayerPerceptronClassificationModel") {
val mlp = new MultilayerPerceptronClassifier().setLayers(Array(2, 3, 2)).setMaxIter(5)
val mlpModel = mlp.fit(dataset)
val newMlpModel = testDefaultReadWrite(mlpModel, testParams = true)
assert(newMlpModel.layers === mlpModel.layers)
assert(newMlpModel.weights === mlpModel.weights)
}
} }
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment