Commit 1052d364 authored by Yanbo Liang's avatar Yanbo Liang Committed by Joseph K. Bradley

[SPARK-15362][ML] Make spark.ml KMeansModel load backwards compatible

## What changes were proposed in this pull request?
[SPARK-14646](https://issues.apache.org/jira/browse/SPARK-14646) makes ```KMeansModel``` store the cluster centers one per row. The ```KMeansModel.load()``` method needs to be updated so that it can also load models saved with Spark 1.6.
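The fix branches on the Spark version recorded in the saved model's metadata. A minimal standalone sketch of that version check (the `VersionCheck` object is hypothetical, not part of the Spark source; the regex is the one used in the patch):

```scala
object VersionCheck {
  // Matches a version string like "2.0.0" or "1.6.2" and captures the major version.
  private val versionRegex = "([0-9]+)\\.(.+)".r

  /**
   * Returns true when the saved model uses the Spark 2.x layout
   * (one cluster center per row); false for the 1.6-and-earlier
   * layout (all centers in a single row).
   */
  def usesNewFormat(sparkVersion: String): Boolean = {
    val versionRegex(major, _) = sparkVersion
    major.toInt >= 2
  }
}
```

Gating on the major version rather than on the presence of particular columns keeps the reader deterministic: the metadata always records which Spark version wrote the model, so no schema probing is needed.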

## How was this patch tested?
Since ```save/load``` is ```Experimental``` for 1.6, I think an offline test for backwards compatibility is enough.

Author: Yanbo Liang <ybliang8@gmail.com>

Closes #13149 from yanboliang/spark-15362.
parent 3facca51
@@ -185,6 +185,12 @@ object KMeansModel extends MLReadable[KMeansModel] {
   /** Helper class for storing model data */
   private case class Data(clusterIdx: Int, clusterCenter: Vector)
 
+  /**
+   * Helper class for storing model data in the format used by Spark 1.6 and earlier,
+   * which stored all cluster centers in a single row. Kept for backward compatibility.
+   */
+  private case class OldData(clusterCenters: Array[OldVector])
+
   /** [[MLWriter]] instance for [[KMeansModel]] */
   private[KMeansModel] class KMeansModelWriter(instance: KMeansModel) extends MLWriter {
@@ -211,13 +217,19 @@ object KMeansModel extends MLReadable[KMeansModel] {
       import sqlContext.implicits._
       val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
       val dataPath = new Path(path, "data").toString
-      val data: Dataset[Data] = sqlContext.read.parquet(dataPath).as[Data]
-      val clusterCenters = data.collect().sortBy(_.clusterIdx).map(_.clusterCenter)
-      val model = new KMeansModel(metadata.uid,
-        new MLlibKMeansModel(clusterCenters.map(OldVectors.fromML)))
+
+      val versionRegex = "([0-9]+)\\.(.+)".r
+      val versionRegex(major, _) = metadata.sparkVersion
+      val clusterCenters = if (major.toInt >= 2) {
+        val data: Dataset[Data] = sqlContext.read.parquet(dataPath).as[Data]
+        data.collect().sortBy(_.clusterIdx).map(_.clusterCenter).map(OldVectors.fromML)
+      } else {
+        // Loads KMeansModel stored with the old format used by Spark 1.6 and earlier.
+        sqlContext.read.parquet(dataPath).as[OldData].head().clusterCenters
+      }
+      val model = new KMeansModel(metadata.uid, new MLlibKMeansModel(clusterCenters))
       DefaultParamsReader.getAndSetParams(model, metadata)
       model
     }