Skip to content
Snippets Groups Projects
Commit 2cf46d5a authored by Xusen Yin's avatar Xusen Yin Committed by Xiangrui Meng
Browse files

[SPARK-11871] Add save/load for MLPC

## What changes were proposed in this pull request?

https://issues.apache.org/jira/browse/SPARK-11871

Add save/load for MLPC

## How was this patch tested?

Test with Scala unit test

Author: Xusen Yin <yinxusen@gmail.com>

Closes #9854 from yinxusen/SPARK-11871.
parent d283223a
No related branches found
No related tags found
No related merge requests found
......@@ -19,12 +19,14 @@ package org.apache.spark.ml.classification
import scala.collection.JavaConverters._
import org.apache.hadoop.fs.Path
import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml.{PredictionModel, Predictor, PredictorParams}
import org.apache.spark.ml.ann.{FeedForwardTopology, FeedForwardTrainer}
import org.apache.spark.ml.param.{IntArrayParam, IntParam, ParamMap, ParamValidators}
import org.apache.spark.ml.param.shared.{HasMaxIter, HasSeed, HasTol}
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.ml.util._
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.sql.DataFrame
......@@ -110,7 +112,7 @@ private object LabelConverter {
class MultilayerPerceptronClassifier @Since("1.5.0") (
@Since("1.5.0") override val uid: String)
extends Predictor[Vector, MultilayerPerceptronClassifier, MultilayerPerceptronClassificationModel]
with MultilayerPerceptronParams {
with MultilayerPerceptronParams with DefaultParamsWritable {
@Since("1.5.0")
def this() = this(Identifiable.randomUID("mlpc"))
......@@ -172,6 +174,14 @@ class MultilayerPerceptronClassifier @Since("1.5.0") (
}
}
@Since("2.0.0")
object MultilayerPerceptronClassifier
extends DefaultParamsReadable[MultilayerPerceptronClassifier] {
@Since("2.0.0")
override def load(path: String): MultilayerPerceptronClassifier = super.load(path)
}
/**
* :: Experimental ::
* Classification model based on the Multilayer Perceptron.
......@@ -188,7 +198,7 @@ class MultilayerPerceptronClassificationModel private[ml] (
@Since("1.5.0") val layers: Array[Int],
@Since("1.5.0") val weights: Vector)
extends PredictionModel[Vector, MultilayerPerceptronClassificationModel]
with Serializable {
with Serializable with MLWritable {
@Since("1.6.0")
override val numFeatures: Int = layers.head
......@@ -214,4 +224,57 @@ class MultilayerPerceptronClassificationModel private[ml] (
override def copy(extra: ParamMap): MultilayerPerceptronClassificationModel = {
copyValues(new MultilayerPerceptronClassificationModel(uid, layers, weights), extra)
}
@Since("2.0.0")
override def write: MLWriter =
new MultilayerPerceptronClassificationModel.MultilayerPerceptronClassificationModelWriter(this)
}
@Since("2.0.0")
object MultilayerPerceptronClassificationModel
extends MLReadable[MultilayerPerceptronClassificationModel] {
@Since("2.0.0")
override def read: MLReader[MultilayerPerceptronClassificationModel] =
new MultilayerPerceptronClassificationModelReader
@Since("2.0.0")
override def load(path: String): MultilayerPerceptronClassificationModel = super.load(path)
/** [[MLWriter]] instance for [[MultilayerPerceptronClassificationModel]] */
private[MultilayerPerceptronClassificationModel]
class MultilayerPerceptronClassificationModelWriter(
instance: MultilayerPerceptronClassificationModel) extends MLWriter {
private case class Data(layers: Array[Int], weights: Vector)
override protected def saveImpl(path: String): Unit = {
// Save metadata and Params
DefaultParamsWriter.saveMetadata(instance, path, sc)
// Save model data: layers, weights
val data = Data(instance.layers, instance.weights)
val dataPath = new Path(path, "data").toString
sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
}
}
private class MultilayerPerceptronClassificationModelReader
extends MLReader[MultilayerPerceptronClassificationModel] {
/** Checked against metadata when loading model */
private val className = classOf[MultilayerPerceptronClassificationModel].getName
override def load(path: String): MultilayerPerceptronClassificationModel = {
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
val dataPath = new Path(path, "data").toString
val data = sqlContext.read.parquet(dataPath).select("layers", "weights").head()
val layers = data.getAs[Seq[Int]](0).toArray
val weights = data.getAs[Vector](1)
val model = new MultilayerPerceptronClassificationModel(metadata.uid, layers, weights)
DefaultParamsReader.getAndSetParams(model, metadata)
model
}
}
}
......@@ -18,31 +18,40 @@
package org.apache.spark.ml.classification
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.util.DefaultReadWriteTest
import org.apache.spark.mllib.classification.LogisticRegressionSuite._
import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
import org.apache.spark.mllib.evaluation.MulticlassMetrics
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.sql.Row
import org.apache.spark.sql.{DataFrame, Row}
class MultilayerPerceptronClassifierSuite extends SparkFunSuite with MLlibTestSparkContext {
class MultilayerPerceptronClassifierSuite
extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
test("XOR function learning as binary classification problem with two outputs.") {
val dataFrame = sqlContext.createDataFrame(Seq(
@transient var dataset: DataFrame = _
override def beforeAll(): Unit = {
super.beforeAll()
dataset = sqlContext.createDataFrame(Seq(
(Vectors.dense(0.0, 0.0), 0.0),
(Vectors.dense(0.0, 1.0), 1.0),
(Vectors.dense(1.0, 0.0), 1.0),
(Vectors.dense(1.0, 1.0), 0.0))
).toDF("features", "label")
}
test("XOR function learning as binary classification problem with two outputs.") {
val layers = Array[Int](2, 5, 2)
val trainer = new MultilayerPerceptronClassifier()
.setLayers(layers)
.setBlockSize(1)
.setSeed(11L)
.setMaxIter(100)
val model = trainer.fit(dataFrame)
val result = model.transform(dataFrame)
val model = trainer.fit(dataset)
val result = model.transform(dataset)
val predictionAndLabels = result.select("prediction", "label").collect()
predictionAndLabels.foreach { case Row(p: Double, l: Double) =>
assert(p == l)
......@@ -92,4 +101,26 @@ class MultilayerPerceptronClassifierSuite extends SparkFunSuite with MLlibTestSp
val mlpMetrics = new MulticlassMetrics(mlpPredictionAndLabels)
assert(mlpMetrics.confusionMatrix ~== lrMetrics.confusionMatrix absTol 100)
}
test("read/write: MultilayerPerceptronClassifier") {
val mlp = new MultilayerPerceptronClassifier()
.setLayers(Array(2, 3, 2))
.setMaxIter(5)
.setBlockSize(2)
.setSeed(42)
.setTol(0.1)
.setFeaturesCol("myFeatures")
.setLabelCol("myLabel")
.setPredictionCol("myPrediction")
testDefaultReadWrite(mlp, testParams = true)
}
test("read/write: MultilayerPerceptronClassificationModel") {
val mlp = new MultilayerPerceptronClassifier().setLayers(Array(2, 3, 2)).setMaxIter(5)
val mlpModel = mlp.fit(dataset)
val newMlpModel = testDefaultReadWrite(mlpModel, testParams = true)
assert(newMlpModel.layers === mlpModel.layers)
assert(newMlpModel.weights === mlpModel.weights)
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment