Commit 56d3a7eb authored by Sean Owen

[SPARK-18808][ML][MLLIB] ml.KMeansModel.transform is very inefficient

## What changes were proposed in this pull request?

mllib.KMeansModel.clusterCentersWithNorm is a method that ends up being called every time `predict` is called on a single vector, which is bad news for how the ml.KMeansModel Transformer works, since it necessarily transforms one vector at a time.

This change makes the model store the vectors with their norms up front. The extra norm per center is small compared to the vectors themselves, and caching it avoids this form of overhead on this and other code paths.
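For illustration, here is a minimal, self-contained sketch of the pattern this change applies (the class and helpers below are simplified stand-ins, not the real Spark types): work that a `def` redid on every `predict` call is hoisted into a `val` computed once at construction.

```scala
// Simplified stand-in for mllib's VectorWithNorm: a vector paired with its L2 norm.
case class VectorWithNorm(vector: Array[Double], norm: Double)

object VectorWithNorm {
  def apply(v: Array[Double]): VectorWithNorm =
    VectorWithNorm(v, math.sqrt(v.map(x => x * x).sum))
}

class SimpleKMeansModel(val clusterCenters: Array[Array[Double]]) {

  // Before: a def, so the norms were rebuilt on every predict() call.
  // private def clusterCentersWithNorm: Iterable[VectorWithNorm] =
  //   clusterCenters.map(VectorWithNorm(_))

  // After: a val computed once up front (the null guard mirrors the patch).
  private val clusterCentersWithNorm: Array[VectorWithNorm] =
    if (clusterCenters == null) null else clusterCenters.map(VectorWithNorm(_))

  // One-vector-at-a-time prediction, as the ml.KMeansModel Transformer does.
  def predict(point: Array[Double]): Int = {
    val p = VectorWithNorm(point)
    clusterCentersWithNorm.indices.minBy { i =>
      val c = clusterCentersWithNorm(i)
      // Squared distance via the cached norms: |c|^2 + |p|^2 - 2 * (c . p)
      c.norm * c.norm + p.norm * p.norm -
        2.0 * c.vector.zip(p.vector).map { case (a, b) => a * b }.sum
    }
  }
}

// Example: the second center (index 1) is closest to (4, 6).
val model = new SimpleKMeansModel(Array(Array(0.0, 0.0), Array(5.0, 5.0)))
println(model.predict(Array(4.0, 6.0)))  // 1
```

The real KMeans.findClosest does more than this sketch (it uses the cached norms to cheaply lower-bound distances); the sketch only shows the caching itself.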

## How was this patch tested?

Existing tests.

Author: Sean Owen <sowen@cloudera.com>

Closes #16328 from srowen/SPARK-18808.
parent 63036aee
@@ -39,6 +39,9 @@ import org.apache.spark.sql.{Row, SparkSession}
 class KMeansModel @Since("1.1.0") (@Since("1.0.0") val clusterCenters: Array[Vector])
   extends Saveable with Serializable with PMMLExportable {
 
+  private val clusterCentersWithNorm =
+    if (clusterCenters == null) null else clusterCenters.map(new VectorWithNorm(_))
+
   /**
    * A Java-friendly constructor that takes an Iterable of Vectors.
    */
@@ -49,7 +52,7 @@ class KMeansModel @Since("1.1.0") (@Since("1.0.0") val clusterCenters: Array[Vec
    * Total number of clusters.
    */
   @Since("0.8.0")
-  def k: Int = clusterCenters.length
+  def k: Int = clusterCentersWithNorm.length
 
   /**
    * Returns the cluster index that a given point belongs to.
@@ -64,8 +67,7 @@ class KMeansModel @Since("1.1.0") (@Since("1.0.0") val clusterCenters: Array[Vec
    */
   @Since("1.0.0")
   def predict(points: RDD[Vector]): RDD[Int] = {
-    val centersWithNorm = clusterCentersWithNorm
-    val bcCentersWithNorm = points.context.broadcast(centersWithNorm)
+    val bcCentersWithNorm = points.context.broadcast(clusterCentersWithNorm)
     points.map(p => KMeans.findClosest(bcCentersWithNorm.value, new VectorWithNorm(p))._1)
   }
@@ -82,13 +84,10 @@ class KMeansModel @Since("1.1.0") (@Since("1.0.0") val clusterCenters: Array[Vec
    */
   @Since("0.8.0")
   def computeCost(data: RDD[Vector]): Double = {
-    val centersWithNorm = clusterCentersWithNorm
-    val bcCentersWithNorm = data.context.broadcast(centersWithNorm)
+    val bcCentersWithNorm = data.context.broadcast(clusterCentersWithNorm)
     data.map(p => KMeans.pointCost(bcCentersWithNorm.value, new VectorWithNorm(p))).sum()
   }
 
-  private def clusterCentersWithNorm: Iterable[VectorWithNorm] =
-    clusterCenters.map(new VectorWithNorm(_))
-
   @Since("1.4.0")
   override def save(sc: SparkContext, path: String): Unit = {
@@ -127,8 +126,8 @@ object KMeansModel extends Loader[KMeansModel] {
       val metadata = compact(render(
         ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ ("k" -> model.k)))
       sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path))
-      val dataRDD = sc.parallelize(model.clusterCenters.zipWithIndex).map { case (point, id) =>
-        Cluster(id, point)
+      val dataRDD = sc.parallelize(model.clusterCentersWithNorm.zipWithIndex).map { case (p, id) =>
+        Cluster(id, p.vector)
       }
       spark.createDataFrame(dataRDD).write.parquet(Loader.dataPath(path))
     }
@@ -145,7 +145,7 @@ class StreamingKMeansModel @Since("1.2.0") (
       }
     }
 
-    this
+    new StreamingKMeansModel(clusterCenters, clusterWeights)
   }
 }
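One inference worth spelling out about the last hunk (my reading, not text from the original commit message): StreamingKMeansModel.update mutates the centers array in place, and a `val` cached from that array at construction would go stale, which is presumably why update now returns a freshly constructed model instead of `this`. A tiny hypothetical sketch of the hazard:

```scala
// Hypothetical illustration: a val cached from a mutable array does not track
// later in-place updates to that array.
class CachedModel(val centers: Array[Array[Double]]) {
  // Computed once at construction.
  val norms: Array[Double] = centers.map(c => math.sqrt(c.map(x => x * x).sum))
}

val m = new CachedModel(Array(Array(3.0, 4.0)))
m.centers(0)(0) = 30.0        // in-place mutation, as streaming k-means does
println(m.norms(0))           // still 5.0: the cached norm is stale
val fresh = new CachedModel(m.centers)
println(fresh.norms(0))       // ~30.27: recomputed from the mutated centers
```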