Skip to content
Snippets Groups Projects
Commit 30e00955 authored by Yanbo Liang, committed by Xiangrui Meng
Browse files

[SPARK-13035][ML][PYSPARK] PySpark ml.clustering support export/import

PySpark ml.clustering now supports export/import.

Author: Yanbo Liang <ybliang8@gmail.com>

Closes #10999 from yanboliang/spark-13035.
parent 2426eb3e
No related branches found
No related tags found
No related merge requests found
......@@ -16,7 +16,7 @@
#
from pyspark import since
from pyspark.ml.util import keyword_only
from pyspark.ml.util import *
from pyspark.ml.wrapper import JavaEstimator, JavaModel
from pyspark.ml.param.shared import *
from pyspark.mllib.common import inherit_doc
......@@ -24,7 +24,7 @@ from pyspark.mllib.common import inherit_doc
__all__ = ['KMeans', 'KMeansModel']
class KMeansModel(JavaModel):
class KMeansModel(JavaModel, MLWritable, MLReadable):
"""
Model fitted by KMeans.
......@@ -46,7 +46,8 @@ class KMeansModel(JavaModel):
@inherit_doc
class KMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIter, HasTol, HasSeed):
class KMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIter, HasTol, HasSeed,
MLWritable, MLReadable):
"""
K-means clustering with support for multiple parallel runs and a k-means++ like initialization
mode (the k-means|| algorithm by Bahmani et al). When multiple concurrent runs are requested,
......@@ -69,6 +70,25 @@ class KMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIter, HasTol
True
>>> rows[2].prediction == rows[3].prediction
True
>>> import os, tempfile
>>> path = tempfile.mkdtemp()
>>> kmeans_path = path + "/kmeans"
>>> kmeans.save(kmeans_path)
>>> kmeans2 = KMeans.load(kmeans_path)
>>> kmeans2.getK()
2
>>> model_path = path + "/kmeans_model"
>>> model.save(model_path)
>>> model2 = KMeansModel.load(model_path)
>>> model.clusterCenters()[0] == model2.clusterCenters()[0]
array([ True, True], dtype=bool)
>>> model.clusterCenters()[1] == model2.clusterCenters()[1]
array([ True, True], dtype=bool)
>>> from shutil import rmtree
>>> try:
... rmtree(path)
... except OSError:
... pass
.. versionadded:: 1.5.0
"""
......@@ -157,9 +177,10 @@ class KMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIter, HasTol
if __name__ == "__main__":
import doctest
import pyspark.ml.clustering
from pyspark.context import SparkContext
from pyspark.sql import SQLContext
globs = globals().copy()
globs = pyspark.ml.clustering.__dict__.copy()
# The small batch size here ensures that we see multiple batches,
# even in these small test examples:
sc = SparkContext("local[2]", "ml.clustering tests")
......
0% Loading.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment