Skip to content
Snippets Groups Projects
Commit 4c82ca86 authored by Jeff Zhang's avatar Jeff Zhang Committed by Yanbo Liang
Browse files

[SPARK-15819][PYSPARK][ML] Add KMeanSummary in KMeans of PySpark

## What changes were proposed in this pull request?

Add python api for KMeansSummary
## How was this patch tested?

unit test added

Author: Jeff Zhang <zjffdu@apache.org>

Closes #13557 from zjffdu/SPARK-15819.
parent 489845f3
No related branches found
No related tags found
No related merge requests found
...@@ -292,6 +292,17 @@ class GaussianMixtureSummary(ClusteringSummary): ...@@ -292,6 +292,17 @@ class GaussianMixtureSummary(ClusteringSummary):
return self._call_java("probability") return self._call_java("probability")
class KMeansSummary(ClusteringSummary):
"""
.. note:: Experimental
Summary of KMeans.
.. versionadded:: 2.1.0
"""
pass
class KMeansModel(JavaModel, JavaMLWritable, JavaMLReadable): class KMeansModel(JavaModel, JavaMLWritable, JavaMLReadable):
""" """
Model fitted by KMeans. Model fitted by KMeans.
...@@ -312,6 +323,27 @@ class KMeansModel(JavaModel, JavaMLWritable, JavaMLReadable): ...@@ -312,6 +323,27 @@ class KMeansModel(JavaModel, JavaMLWritable, JavaMLReadable):
""" """
return self._call_java("computeCost", dataset) return self._call_java("computeCost", dataset)
@property
@since("2.1.0")
def hasSummary(self):
"""
Indicates whether a training summary exists for this model instance.
"""
return self._call_java("hasSummary")
@property
@since("2.1.0")
def summary(self):
"""
Gets summary (e.g. cluster assignments, cluster sizes) of the model trained on the
training set. An exception is thrown if no summary exists.
"""
if self.hasSummary:
return KMeansSummary(self._call_java("summary"))
else:
raise RuntimeError("No training summary available for this %s" %
self.__class__.__name__)
@inherit_doc @inherit_doc
class KMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIter, HasTol, HasSeed, class KMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIter, HasTol, HasSeed,
...@@ -337,6 +369,13 @@ class KMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIter, HasTol ...@@ -337,6 +369,13 @@ class KMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIter, HasTol
True True
>>> rows[2].prediction == rows[3].prediction >>> rows[2].prediction == rows[3].prediction
True True
>>> model.hasSummary
True
>>> summary = model.summary
>>> summary.k
2
>>> summary.clusterSizes
[2, 2]
>>> kmeans_path = temp_path + "/kmeans" >>> kmeans_path = temp_path + "/kmeans"
>>> kmeans.save(kmeans_path) >>> kmeans.save(kmeans_path)
>>> kmeans2 = KMeans.load(kmeans_path) >>> kmeans2 = KMeans.load(kmeans_path)
...@@ -345,6 +384,8 @@ class KMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIter, HasTol ...@@ -345,6 +384,8 @@ class KMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIter, HasTol
>>> model_path = temp_path + "/kmeans_model" >>> model_path = temp_path + "/kmeans_model"
>>> model.save(model_path) >>> model.save(model_path)
>>> model2 = KMeansModel.load(model_path) >>> model2 = KMeansModel.load(model_path)
>>> model2.hasSummary
False
>>> model.clusterCenters()[0] == model2.clusterCenters()[0] >>> model.clusterCenters()[0] == model2.clusterCenters()[0]
array([ True, True], dtype=bool) array([ True, True], dtype=bool)
>>> model.clusterCenters()[1] == model2.clusterCenters()[1] >>> model.clusterCenters()[1] == model2.clusterCenters()[1]
......
...@@ -1129,6 +1129,21 @@ class TrainingSummaryTest(SparkSessionTestCase): ...@@ -1129,6 +1129,21 @@ class TrainingSummaryTest(SparkSessionTestCase):
self.assertEqual(len(s.clusterSizes), 2) self.assertEqual(len(s.clusterSizes), 2)
self.assertEqual(s.k, 2) self.assertEqual(s.k, 2)
def test_kmeans_summary(self):
data = [(Vectors.dense([0.0, 0.0]),), (Vectors.dense([1.0, 1.0]),),
(Vectors.dense([9.0, 8.0]),), (Vectors.dense([8.0, 9.0]),)]
df = self.spark.createDataFrame(data, ["features"])
kmeans = KMeans(k=2, seed=1)
model = kmeans.fit(df)
self.assertTrue(model.hasSummary)
s = model.summary
self.assertTrue(isinstance(s.predictions, DataFrame))
self.assertEqual(s.featuresCol, "features")
self.assertEqual(s.predictionCol, "prediction")
self.assertTrue(isinstance(s.cluster, DataFrame))
self.assertEqual(len(s.clusterSizes), 2)
self.assertEqual(s.k, 2)
class OneVsRestTests(SparkSessionTestCase): class OneVsRestTests(SparkSessionTestCase):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment