Skip to content
Snippets Groups Projects
Commit da936fbb authored by Evan Chen's avatar Evan Chen Committed by Joseph K. Bradley
Browse files

[SPARK-10779] [PYSPARK] [MLLIB] Set initialModel for KMeans model in PySpark (spark.mllib)

Provide initialModel param for pyspark.mllib.clustering.KMeans

Author: Evan Chen <chene@us.ibm.com>

Closes #8967 from evanyc15/SPARK-10779-pyspark-mllib.
parent 713e4f44
No related branches found
No related tags found
No related merge requests found
......@@ -336,7 +336,8 @@ private[python] class PythonMLLibAPI extends Serializable {
initializationMode: String,
seed: java.lang.Long,
initializationSteps: Int,
epsilon: Double): KMeansModel = {
epsilon: Double,
initialModel: java.util.ArrayList[Vector]): KMeansModel = {
val kMeansAlg = new KMeans()
.setK(k)
.setMaxIterations(maxIterations)
......@@ -346,6 +347,7 @@ private[python] class PythonMLLibAPI extends Serializable {
.setEpsilon(epsilon)
if (seed != null) kMeansAlg.setSeed(seed)
if (!initialModel.isEmpty()) kMeansAlg.setInitialModel(new KMeansModel(initialModel))
try {
kMeansAlg.run(data.rdd.persist(StorageLevel.MEMORY_AND_DISK))
......
......@@ -90,6 +90,12 @@ class KMeansModel(Saveable, Loader):
... rmtree(path)
... except OSError:
... pass
>>> data = array([-383.1,-382.9, 28.7,31.2, 366.2,367.3]).reshape(3, 2)
>>> model = KMeans.train(sc.parallelize(data), 3, maxIterations=0,
... initialModel = KMeansModel([(-1000.0,-1000.0),(5.0,5.0),(1000.0,1000.0)]))
>>> model.clusterCenters
[array([-1000., -1000.]), array([ 5., 5.]), array([ 1000., 1000.])]
"""
def __init__(self, centers):
......@@ -144,10 +150,17 @@ class KMeans(object):
@classmethod
def train(cls, rdd, k, maxIterations=100, runs=1, initializationMode="k-means||",
seed=None, initializationSteps=5, epsilon=1e-4):
seed=None, initializationSteps=5, epsilon=1e-4, initialModel=None):
"""Train a k-means clustering model."""
clusterInitialModel = []
if initialModel is not None:
if not isinstance(initialModel, KMeansModel):
raise Exception("initialModel is of "+str(type(initialModel))+". It needs "
"to be of <type 'KMeansModel'>")
clusterInitialModel = [_convert_to_vector(c) for c in initialModel.clusterCenters]
model = callMLlibFunc("trainKMeansModel", rdd.map(_convert_to_vector), k, maxIterations,
runs, initializationMode, seed, initializationSteps, epsilon)
runs, initializationMode, seed, initializationSteps, epsilon,
clusterInitialModel)
centers = callJavaFunc(rdd.context, model.clusterCenters)
return KMeansModel([c.toArray() for c in centers])
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment