diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala
index aa78149699a277a37c299090ddbd236bbb007a21..df2a9c0dd5094f448f0efb5e1d5388166adf9b3b 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala
@@ -39,6 +39,9 @@ import org.apache.spark.sql.{Row, SparkSession}
 class KMeansModel @Since("1.1.0") (@Since("1.0.0") val clusterCenters: Array[Vector])
   extends Saveable with Serializable with PMMLExportable {
 
+  private val clusterCentersWithNorm =
+    if (clusterCenters == null) null else clusterCenters.map(new VectorWithNorm(_))
+
   /**
    * A Java-friendly constructor that takes an Iterable of Vectors.
    */
@@ -49,7 +52,7 @@ class KMeansModel @Since("1.1.0") (@Since("1.0.0") val clusterCenters: Array[Vec
    * Total number of clusters.
    */
   @Since("0.8.0")
-  def k: Int = clusterCenters.length
+  def k: Int = clusterCentersWithNorm.length
 
   /**
    * Returns the cluster index that a given point belongs to.
@@ -64,8 +67,7 @@ class KMeansModel @Since("1.1.0") (@Since("1.0.0") val clusterCenters: Array[Vec
    */
   @Since("1.0.0")
   def predict(points: RDD[Vector]): RDD[Int] = {
-    val centersWithNorm = clusterCentersWithNorm
-    val bcCentersWithNorm = points.context.broadcast(centersWithNorm)
+    val bcCentersWithNorm = points.context.broadcast(clusterCentersWithNorm)
     points.map(p => KMeans.findClosest(bcCentersWithNorm.value, new VectorWithNorm(p))._1)
   }
 
@@ -82,13 +84,10 @@ class KMeansModel @Since("1.1.0") (@Since("1.0.0") val clusterCenters: Array[Vec
    */
   @Since("0.8.0")
   def computeCost(data: RDD[Vector]): Double = {
-    val centersWithNorm = clusterCentersWithNorm
-    val bcCentersWithNorm = data.context.broadcast(centersWithNorm)
+    val bcCentersWithNorm = data.context.broadcast(clusterCentersWithNorm)
     data.map(p => KMeans.pointCost(bcCentersWithNorm.value, new VectorWithNorm(p))).sum()
   }
 
-  private def clusterCentersWithNorm: Iterable[VectorWithNorm] =
-    clusterCenters.map(new VectorWithNorm(_))
 
   @Since("1.4.0")
   override def save(sc: SparkContext, path: String): Unit = {
@@ -127,8 +126,8 @@ object KMeansModel extends Loader[KMeansModel] {
       val metadata = compact(render(
         ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ ("k" -> model.k)))
       sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path))
-      val dataRDD = sc.parallelize(model.clusterCenters.zipWithIndex).map { case (point, id) =>
-        Cluster(id, point)
+      val dataRDD = sc.parallelize(model.clusterCentersWithNorm.zipWithIndex).map { case (p, id) =>
+        Cluster(id, p.vector)
       }
       spark.createDataFrame(dataRDD).write.parquet(Loader.dataPath(path))
     }
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala
index 85c37c438d93a1e5e5e304e9cb1213beb964d05b..3ca75e8cdb97a328fab8dd72fb115988e5ea8702 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala
@@ -145,7 +145,7 @@ class StreamingKMeansModel @Since("1.2.0") (
       }
     }
 
-    this
+    new StreamingKMeansModel(clusterCenters, clusterWeights)
   }
 }
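Not part of the patch: a minimal usage sketch of the public API touched above, under the assumption that the patched classes behave as shown (cluster centers are wrapped with their norms once at construction and broadcast directly in `predict`/`computeCost`, and `StreamingKMeansModel.update` returns a fresh model rather than `this`). The `KMeansModelSketch` object, the local SparkContext, and the example data are illustrative only.

```scala
// Illustrative sketch, not part of the patch. Assumes a local Spark context and
// the mllib KMeansModel / StreamingKMeansModel API as modified above.
import org.apache.spark.SparkContext
import org.apache.spark.mllib.clustering.{KMeansModel, StreamingKMeansModel}
import org.apache.spark.mllib.linalg.Vectors

object KMeansModelSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext("local[2]", "kmeans-model-sketch")

    // With the patch, the centers' norms are computed once here, at construction,
    // instead of on every predict()/computeCost() call.
    val centers = Array(Vectors.dense(0.0, 0.0), Vectors.dense(5.0, 5.0))
    val model = new KMeansModel(centers)

    val points = sc.parallelize(Seq(Vectors.dense(0.1, 0.2), Vectors.dense(4.9, 5.1)))
    val assignments = model.predict(points).collect() // expected: Array(0, 1)
    val cost = model.computeCost(points)              // sum of squared distances to nearest center

    // update() now returns a new StreamingKMeansModel, so the returned model's
    // precomputed norms are consistent with the updated centers.
    val streaming = new StreamingKMeansModel(centers, Array(1.0, 1.0))
    val updated = streaming.update(points, decayFactor = 1.0, timeUnit = "batches")

    println(s"assignments=${assignments.mkString(",")} cost=$cost k=${updated.k}")
    sc.stop()
  }
}
```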