Skip to content
Snippets Groups Projects
Unverified Commit 445d4d9e authored by Zakaria_Hili's avatar Zakaria_Hili Committed by Sean Owen
Browse files

[SPARK-18356][ML] Improve MLKmeans Performance

## What changes were proposed in this pull request?

Spark Kmeans fit() doesn't cache the RDD which generates a lot of warnings :
 WARN KMeans: The input data is not directly cached, which may hurt performance if its parent RDDs are also uncached.
So, Kmeans should cache the internal rdd before calling the Mllib.Kmeans algo, this helped to improve spark kmeans performance by 14%

https://github.com/ZakariaHili/spark/commit/a9cf905cf7dbd50eeb9a8b4f891f2f41ea672472

hhbyyh
## How was this patch tested?
Pass Kmeans tests and existing tests

Author: Zakaria_Hili <zakahili@gmail.com>
Author: HILI Zakaria <zakahili@gmail.com>

Closes #15965 from ZakariaHili/zakbranch.
parent 5ecdc7c5
No related branches found
No related tags found
No related merge requests found
......@@ -33,6 +33,7 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.sql.functions.{col, udf}
import org.apache.spark.sql.types.{IntegerType, StructType}
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.VersionUtils.majorVersion
/**
......@@ -305,12 +306,20 @@ class KMeans @Since("1.5.0") (
@Since("2.0.0")
override def fit(dataset: Dataset[_]): KMeansModel = {
val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE
fit(dataset, handlePersistence)
}
@Since("2.2.0")
protected def fit(dataset: Dataset[_], handlePersistence: Boolean): KMeansModel = {
transformSchema(dataset.schema, logging = true)
val rdd: RDD[OldVector] = dataset.select(col($(featuresCol))).rdd.map {
val instances: RDD[OldVector] = dataset.select(col($(featuresCol))).rdd.map {
case Row(point: Vector) => OldVectors.fromML(point)
}
val instr = Instrumentation.create(this, rdd)
if (handlePersistence) {
instances.persist(StorageLevel.MEMORY_AND_DISK)
}
val instr = Instrumentation.create(this, instances)
instr.logParams(featuresCol, predictionCol, k, initMode, initSteps, maxIter, seed, tol)
val algo = new MLlibKMeans()
......@@ -320,12 +329,15 @@ class KMeans @Since("1.5.0") (
.setMaxIterations($(maxIter))
.setSeed($(seed))
.setEpsilon($(tol))
val parentModel = algo.run(rdd, Option(instr))
val parentModel = algo.run(instances, Option(instr))
val model = copyValues(new KMeansModel(uid, parentModel).setParent(this))
val summary = new KMeansSummary(
model.transform(dataset), $(predictionCol), $(featuresCol), $(k))
model.setSummary(Some(summary))
instr.logSuccess(model)
if (handlePersistence) {
instances.unpersist()
}
model
}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment