Skip to content
Snippets Groups Projects
Commit 1eda2f10 authored by Joseph K. Bradley's avatar Joseph K. Bradley
Browse files

[SPARK-14646][ML] Modified Kmeans to store cluster centers with one per row

## What changes were proposed in this pull request?

Modified Kmeans to store cluster centers with one per row

## How was this patch tested?

Existing tests

Author: Joseph K. Bradley <joseph@databricks.com>

Closes #12792 from jkbradley/kmeans-save-fix.
parent d33e3d57
No related branches found
No related tags found
No related merge requests found
......@@ -169,18 +169,21 @@ object KMeansModel extends MLReadable[KMeansModel] {
@Since("1.6.0")
override def load(path: String): KMeansModel = super.load(path)
/** Helper class for storing model data */
private case class Data(clusterIdx: Int, clusterCenter: Vector)
/** [[MLWriter]] instance for [[KMeansModel]] */
private[KMeansModel] class KMeansModelWriter(instance: KMeansModel) extends MLWriter {
private case class Data(clusterCenters: Array[Vector])
override protected def saveImpl(path: String): Unit = {
// Save metadata and Params
DefaultParamsWriter.saveMetadata(instance, path, sc)
// Save model data: cluster centers
val data = Data(instance.clusterCenters)
val data: Array[Data] = instance.clusterCenters.zipWithIndex.map { case (center, idx) =>
Data(idx, center)
}
val dataPath = new Path(path, "data").toString
sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
sqlContext.createDataFrame(data).repartition(1).write.parquet(dataPath)
}
}
......@@ -190,11 +193,15 @@ object KMeansModel extends MLReadable[KMeansModel] {
private val className = classOf[KMeansModel].getName
override def load(path: String): KMeansModel = {
// Import implicits for Dataset Encoder
val sqlContext = super.sqlContext
import sqlContext.implicits._
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
val dataPath = new Path(path, "data").toString
val data = sqlContext.read.parquet(dataPath).select("clusterCenters").head()
val clusterCenters = data.getAs[Seq[Vector]](0).toArray
val data: Dataset[Data] = sqlContext.read.parquet(dataPath).as[Data]
val clusterCenters = data.collect().sortBy(_.clusterIdx).map(_.clusterCenter)
val model = new KMeansModel(metadata.uid, new MLlibKMeansModel(clusterCenters))
DefaultParamsReader.getAndSetParams(model, metadata)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment