Commit 3e1d120c authored by Xusen Yin, committed by Xiangrui Meng

[SPARK-11867] Add save/load for kmeans and naive bayes

https://issues.apache.org/jira/browse/SPARK-11867

Author: Xusen Yin <yinxusen@gmail.com>

Closes #9849 from yinxusen/SPARK-11867.
parent 0fff8eb3
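For reference, a minimal usage sketch of the API this commit enables (not part of the diff; the SQLContext, the `training` DataFrame, and the /tmp paths are illustrative assumptions):

    import org.apache.spark.ml.classification.{NaiveBayes, NaiveBayesModel}
    import org.apache.spark.ml.clustering.{KMeans, KMeansModel}

    // Assumes an existing `sqlContext` and a DataFrame `training` with
    // "label" and "features" columns; the /tmp paths are placeholders.
    val nbModel = new NaiveBayes().setSmoothing(1.0).fit(training)
    nbModel.write.overwrite().save("/tmp/nbModel")         // model save, via MLWritable
    val nbRestored = NaiveBayesModel.load("/tmp/nbModel")  // model load, via MLReadable

    val kmModel = new KMeans().setK(3).fit(training)
    kmModel.write.overwrite().save("/tmp/kmModel")
    val kmRestored = KMeansModel.load("/tmp/kmModel")

    // The estimators themselves are also persistable (DefaultParamsWritable/Readable):
    new KMeans().setK(3).save("/tmp/kmEstimator")
    val km = KMeans.load("/tmp/kmEstimator")

Estimator save/load round-trips only Params; fitted models additionally persist pi/theta (naive Bayes) and cluster centers (k-means) as Parquet, as the writers in the diff below show.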
mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala

@@ -17,12 +17,15 @@
 package org.apache.spark.ml.classification
 
+import org.apache.hadoop.fs.Path
+
 import org.apache.spark.SparkException
-import org.apache.spark.annotation.Experimental
+import org.apache.spark.annotation.{Experimental, Since}
 import org.apache.spark.ml.PredictorParams
 import org.apache.spark.ml.param.{DoubleParam, Param, ParamMap, ParamValidators}
-import org.apache.spark.ml.util.Identifiable
-import org.apache.spark.mllib.classification.{NaiveBayes => OldNaiveBayes, NaiveBayesModel => OldNaiveBayesModel}
+import org.apache.spark.ml.util._
+import org.apache.spark.mllib.classification.{NaiveBayes => OldNaiveBayes}
+import org.apache.spark.mllib.classification.{NaiveBayesModel => OldNaiveBayesModel}
 import org.apache.spark.mllib.linalg._
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.rdd.RDD
@@ -72,7 +75,7 @@ private[ml] trait NaiveBayesParams extends PredictorParams {
 @Experimental
 class NaiveBayes(override val uid: String)
   extends ProbabilisticClassifier[Vector, NaiveBayes, NaiveBayesModel]
-  with NaiveBayesParams {
+  with NaiveBayesParams with DefaultParamsWritable {
 
   def this() = this(Identifiable.randomUID("nb"))
@@ -102,6 +105,13 @@ class NaiveBayes(override val uid: String)
   override def copy(extra: ParamMap): NaiveBayes = defaultCopy(extra)
 }
 
+@Since("1.6.0")
+object NaiveBayes extends DefaultParamsReadable[NaiveBayes] {
+
+  @Since("1.6.0")
+  override def load(path: String): NaiveBayes = super.load(path)
+}
+
 /**
  * :: Experimental ::
  * Model produced by [[NaiveBayes]]
@@ -114,7 +124,8 @@ class NaiveBayesModel private[ml] (
     override val uid: String,
     val pi: Vector,
     val theta: Matrix)
-  extends ProbabilisticClassificationModel[Vector, NaiveBayesModel] with NaiveBayesParams {
+  extends ProbabilisticClassificationModel[Vector, NaiveBayesModel]
+  with NaiveBayesParams with MLWritable {
 
   import OldNaiveBayes.{Bernoulli, Multinomial}
@@ -203,12 +214,15 @@ class NaiveBayesModel private[ml] (
     s"NaiveBayesModel (uid=$uid) with ${pi.size} classes"
   }
+
+  @Since("1.6.0")
+  override def write: MLWriter = new NaiveBayesModel.NaiveBayesModelWriter(this)
 }
 
-private[ml] object NaiveBayesModel {
+@Since("1.6.0")
+object NaiveBayesModel extends MLReadable[NaiveBayesModel] {
 
   /** Convert a model from the old API */
-  def fromOld(
+  private[ml] def fromOld(
       oldModel: OldNaiveBayesModel,
       parent: NaiveBayes): NaiveBayesModel = {
     val uid = if (parent != null) parent.uid else Identifiable.randomUID("nb")
@@ -218,4 +232,44 @@ private[ml] object NaiveBayesModel {
       oldModel.theta.flatten, true)
     new NaiveBayesModel(uid, pi, theta)
   }
+
+  @Since("1.6.0")
+  override def read: MLReader[NaiveBayesModel] = new NaiveBayesModelReader
+
+  @Since("1.6.0")
+  override def load(path: String): NaiveBayesModel = super.load(path)
+
+  /** [[MLWriter]] instance for [[NaiveBayesModel]] */
+  private[NaiveBayesModel] class NaiveBayesModelWriter(instance: NaiveBayesModel) extends MLWriter {
+
+    private case class Data(pi: Vector, theta: Matrix)
+
+    override protected def saveImpl(path: String): Unit = {
+      // Save metadata and Params
+      DefaultParamsWriter.saveMetadata(instance, path, sc)
+      // Save model data: pi, theta
+      val data = Data(instance.pi, instance.theta)
+      val dataPath = new Path(path, "data").toString
+      sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
+    }
+  }
+
+  private class NaiveBayesModelReader extends MLReader[NaiveBayesModel] {
+
+    /** Checked against metadata when loading model */
+    private val className = classOf[NaiveBayesModel].getName
+
+    override def load(path: String): NaiveBayesModel = {
+      val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+
+      val dataPath = new Path(path, "data").toString
+      val data = sqlContext.read.parquet(dataPath).select("pi", "theta").head()
+      val pi = data.getAs[Vector](0)
+      val theta = data.getAs[Matrix](1)
+      val model = new NaiveBayesModel(metadata.uid, pi, theta)
+      DefaultParamsReader.getAndSetParams(model, metadata)
+      model
+    }
+  }
 }
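On the resulting on-disk layout: DefaultParamsWriter.saveMetadata stores the uid and Params metadata under the save path, and the writer above puts the model data in a single-row Parquet file under <path>/data with "pi" and "theta" columns. A hedged sketch of inspecting that file directly, assuming an existing `sqlContext` and a model already saved at `modelPath` (both placeholders):

    import org.apache.spark.mllib.linalg.{Matrix, Vector}

    // `modelPath` is wherever a NaiveBayesModel was saved (placeholder).
    val row = sqlContext.read.parquet(s"$modelPath/data").select("pi", "theta").head()
    val pi: Vector = row.getAs[Vector](0)     // class log-priors
    val theta: Matrix = row.getAs[Matrix](1)  // per-class feature log-probabilities

This mirrors what NaiveBayesModelReader.load does before calling getAndSetParams.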
mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala

@@ -17,10 +17,12 @@
 package org.apache.spark.ml.clustering
 
-import org.apache.spark.annotation.{Since, Experimental}
-import org.apache.spark.ml.param.{Param, Params, IntParam, ParamMap}
+import org.apache.hadoop.fs.Path
+
+import org.apache.spark.annotation.{Experimental, Since}
 import org.apache.spark.ml.param.shared._
-import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
+import org.apache.spark.ml.param.{IntParam, Param, ParamMap, Params}
+import org.apache.spark.ml.util._
 import org.apache.spark.ml.{Estimator, Model}
 import org.apache.spark.mllib.clustering.{KMeans => MLlibKMeans, KMeansModel => MLlibKMeansModel}
 import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
@@ -28,7 +30,6 @@ import org.apache.spark.sql.functions.{col, udf}
 import org.apache.spark.sql.types.{IntegerType, StructType}
 import org.apache.spark.sql.{DataFrame, Row}
 
-
 /**
  * Common params for KMeans and KMeansModel
  */
@@ -94,7 +95,8 @@ private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFe
 @Experimental
 class KMeansModel private[ml] (
     @Since("1.5.0") override val uid: String,
-    private val parentModel: MLlibKMeansModel) extends Model[KMeansModel] with KMeansParams {
+    private val parentModel: MLlibKMeansModel)
+  extends Model[KMeansModel] with KMeansParams with MLWritable {
 
   @Since("1.5.0")
   override def copy(extra: ParamMap): KMeansModel = {
@@ -129,6 +131,52 @@ class KMeansModel private[ml] (
     val data = dataset.select(col($(featuresCol))).map { case Row(point: Vector) => point }
     parentModel.computeCost(data)
   }
+
+  @Since("1.6.0")
+  override def write: MLWriter = new KMeansModel.KMeansModelWriter(this)
+}
+
+@Since("1.6.0")
+object KMeansModel extends MLReadable[KMeansModel] {
+
+  @Since("1.6.0")
+  override def read: MLReader[KMeansModel] = new KMeansModelReader
+
+  @Since("1.6.0")
+  override def load(path: String): KMeansModel = super.load(path)
+
+  /** [[MLWriter]] instance for [[KMeansModel]] */
+  private[KMeansModel] class KMeansModelWriter(instance: KMeansModel) extends MLWriter {
+
+    private case class Data(clusterCenters: Array[Vector])
+
+    override protected def saveImpl(path: String): Unit = {
+      // Save metadata and Params
+      DefaultParamsWriter.saveMetadata(instance, path, sc)
+      // Save model data: cluster centers
+      val data = Data(instance.clusterCenters)
+      val dataPath = new Path(path, "data").toString
+      sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
+    }
+  }
+
+  private class KMeansModelReader extends MLReader[KMeansModel] {
+
+    /** Checked against metadata when loading model */
+    private val className = classOf[KMeansModel].getName
+
+    override def load(path: String): KMeansModel = {
+      val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+
+      val dataPath = new Path(path, "data").toString
+      val data = sqlContext.read.parquet(dataPath).select("clusterCenters").head()
+      val clusterCenters = data.getAs[Seq[Vector]](0).toArray
+      val model = new KMeansModel(metadata.uid, new MLlibKMeansModel(clusterCenters))
+      DefaultParamsReader.getAndSetParams(model, metadata)
+      model
+    }
+  }
 }
 
 /**
@@ -141,7 +189,7 @@ class KMeansModel private[ml] (
 @Experimental
 class KMeans @Since("1.5.0") (
     @Since("1.5.0") override val uid: String)
-  extends Estimator[KMeansModel] with KMeansParams {
+  extends Estimator[KMeansModel] with KMeansParams with DefaultParamsWritable {
 
   setDefault(
     k -> 2,
@@ -210,3 +258,10 @@ class KMeans @Since("1.5.0") (
   }
 }
+
+@Since("1.6.0")
+object KMeans extends DefaultParamsReadable[KMeans] {
+
+  @Since("1.6.0")
+  override def load(path: String): KMeans = super.load(path)
+}
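The k-means writer and reader above follow the same pattern: Params metadata plus a one-row "clusterCenters" Parquet file under <path>/data, read back as Seq[Vector] and rewrapped in an MLlibKMeansModel. A hedged sketch of peeking at the saved centers, again assuming an illustrative `sqlContext` and `modelPath`:

    import org.apache.spark.mllib.linalg.Vector

    // `modelPath` is wherever a KMeansModel was saved (placeholder).
    val centersRow = sqlContext.read.parquet(s"$modelPath/data").select("clusterCenters").head()
    val centers: Array[Vector] = centersRow.getAs[Seq[Vector]](0).toArray
    centers.zipWithIndex.foreach { case (c, i) => println(s"cluster $i center: $c") }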
mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala

@@ -21,15 +21,30 @@ import breeze.linalg.{Vector => BV}
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.param.ParamsSuite
-import org.apache.spark.mllib.classification.NaiveBayes.{Multinomial, Bernoulli}
+import org.apache.spark.ml.util.DefaultReadWriteTest
+import org.apache.spark.mllib.classification.NaiveBayes.{Bernoulli, Multinomial}
+import org.apache.spark.mllib.classification.NaiveBayesSuite._
 import org.apache.spark.mllib.linalg._
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.mllib.util.TestingUtils._
-import org.apache.spark.mllib.classification.NaiveBayesSuite._
-import org.apache.spark.sql.DataFrame
-import org.apache.spark.sql.Row
+import org.apache.spark.sql.{DataFrame, Row}
+
+class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
+
+  @transient var dataset: DataFrame = _
+
+  override def beforeAll(): Unit = {
+    super.beforeAll()
+
+    val pi = Array(0.5, 0.1, 0.4).map(math.log)
+    val theta = Array(
+      Array(0.70, 0.10, 0.10, 0.10), // label 0
+      Array(0.10, 0.70, 0.10, 0.10), // label 1
+      Array(0.10, 0.10, 0.70, 0.10)  // label 2
+    ).map(_.map(math.log))
 
-class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext {
+    dataset = sqlContext.createDataFrame(generateNaiveBayesInput(pi, theta, 100, 42))
+  }
 
   def validatePrediction(predictionAndLabels: DataFrame): Unit = {
     val numOfErrorPredictions = predictionAndLabels.collect().count {
@@ -161,4 +176,26 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext {
       .select("features", "probability")
     validateProbabilities(featureAndProbabilities, model, "bernoulli")
   }
+
+  test("read/write") {
+    def checkModelData(model: NaiveBayesModel, model2: NaiveBayesModel): Unit = {
+      assert(model.pi === model2.pi)
+      assert(model.theta === model2.theta)
+    }
+    val nb = new NaiveBayes()
+    testEstimatorAndModelReadWrite(nb, dataset, NaiveBayesSuite.allParamSettings, checkModelData)
+  }
+}
+
+object NaiveBayesSuite {
+
+  /**
+   * Mapping from all Params to valid settings which differ from the defaults.
+   * This is useful for tests which need to exercise all Params, such as save/load.
+   * This excludes input columns to simplify some tests.
+   */
+  val allParamSettings: Map[String, Any] = Map(
+    "predictionCol" -> "myPrediction",
+    "smoothing" -> 0.1
+  )
 }
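The read/write test above drives the shared testEstimatorAndModelReadWrite helper from DefaultReadWriteTest. A hand-rolled sketch of the same round trip it exercises; the temp-directory handling and the `dataset` from beforeAll are assumptions outside the diff:

    import java.nio.file.Files
    import org.apache.spark.ml.classification.{NaiveBayes, NaiveBayesModel}

    val path = Files.createTempDirectory("nb-readwrite").toString
    val model = new NaiveBayes().setSmoothing(0.1).fit(dataset)     // `dataset` as built in beforeAll()
    model.write.overwrite().save(path)                              // overwrite: the temp dir already exists
    val model2 = NaiveBayesModel.load(path)
    assert(model.pi == model2.pi)                                   // Vector has structural equality
    assert(model.theta.toArray.sameElements(model2.theta.toArray))  // compare Matrix contents elementwise
    assert(model.getSmoothing == model2.getSmoothing)               // Params restored via getAndSetParams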
mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala

@@ -18,6 +18,7 @@
 package org.apache.spark.ml.clustering
 
 import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.util.DefaultReadWriteTest
 import org.apache.spark.mllib.clustering.{KMeans => MLlibKMeans}
 import org.apache.spark.mllib.linalg.{Vector, Vectors}
 import org.apache.spark.mllib.util.MLlibTestSparkContext
@@ -25,16 +26,7 @@ import org.apache.spark.sql.{DataFrame, SQLContext}
 
 private[clustering] case class TestRow(features: Vector)
 
-object KMeansSuite {
-  def generateKMeansData(sql: SQLContext, rows: Int, dim: Int, k: Int): DataFrame = {
-    val sc = sql.sparkContext
-    val rdd = sc.parallelize(1 to rows).map(i => Vectors.dense(Array.fill(dim)((i % k).toDouble)))
-      .map(v => new TestRow(v))
-    sql.createDataFrame(rdd)
-  }
-}
-
-class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext {
+class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
 
   final val k = 5
   @transient var dataset: DataFrame = _
@@ -106,4 +98,33 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext {
     assert(clusters === Set(0, 1, 2, 3, 4))
     assert(model.computeCost(dataset) < 0.1)
   }
+
+  test("read/write") {
+    def checkModelData(model: KMeansModel, model2: KMeansModel): Unit = {
+      assert(model.clusterCenters === model2.clusterCenters)
+    }
+    val kmeans = new KMeans()
+    testEstimatorAndModelReadWrite(kmeans, dataset, KMeansSuite.allParamSettings, checkModelData)
+  }
+}
+
+object KMeansSuite {
+  def generateKMeansData(sql: SQLContext, rows: Int, dim: Int, k: Int): DataFrame = {
+    val sc = sql.sparkContext
+    val rdd = sc.parallelize(1 to rows).map(i => Vectors.dense(Array.fill(dim)((i % k).toDouble)))
+      .map(v => new TestRow(v))
+    sql.createDataFrame(rdd)
+  }
+
+  /**
+   * Mapping from all Params to valid settings which differ from the defaults.
+   * This is useful for tests which need to exercise all Params, such as save/load.
+   * This excludes input columns to simplify some tests.
+   */
+  val allParamSettings: Map[String, Any] = Map(
+    "predictionCol" -> "myPrediction",
+    "k" -> 3,
+    "maxIter" -> 2,
+    "tol" -> 0.01
+  )
 }