diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py index 60d1c9aaec9887011f490e94569651ade8b2b472..12afb885636332865536d4a9149862e8774b11be 100644 --- a/python/pyspark/ml/clustering.py +++ b/python/pyspark/ml/clustering.py @@ -113,10 +113,6 @@ class KMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIter, HasTol def setK(self, value): """ Sets the value of :py:attr:`k`. - - >>> algo = KMeans().setK(10) - >>> algo.getK() - 10 """ self._paramMap[self.k] = value return self @@ -132,13 +128,6 @@ class KMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIter, HasTol def setInitMode(self, value): """ Sets the value of :py:attr:`initMode`. - - >>> algo = KMeans() - >>> algo.getInitMode() - 'k-means||' - >>> algo = algo.setInitMode("random") - >>> algo.getInitMode() - 'random' """ self._paramMap[self.initMode] = value return self @@ -154,10 +143,6 @@ class KMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIter, HasTol def setInitSteps(self, value): """ Sets the value of :py:attr:`initSteps`. - - >>> algo = KMeans().setInitSteps(10) - >>> algo.getInitSteps() - 10 """ self._paramMap[self.initSteps] = value return self diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 54806ee336666f4f7665164aea3e19bac64f0c6a..e93a4e157b931b142c40fec44bb2ed6aefd2f4fe 100644 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -39,6 +39,7 @@ import tempfile from pyspark.ml import Estimator, Model, Pipeline, Transformer from pyspark.ml.classification import LogisticRegression +from pyspark.ml.clustering import KMeans from pyspark.ml.evaluation import RegressionEvaluator from pyspark.ml.feature import * from pyspark.ml.param import Param, Params @@ -243,6 +244,14 @@ class ParamTests(PySparkTestCase): "maxIter: max number of iterations (>= 0). (default: 10, current: 100)", "seed: random seed. (default: 41, current: 43)"])) + def test_kmeans_param(self): + algo = KMeans() + self.assertEqual(algo.getInitMode(), "k-means||") + algo.setK(10) + self.assertEqual(algo.getK(), 10) + algo.setInitSteps(10) + self.assertEqual(algo.getInitSteps(), 10) + def test_hasseed(self): noSeedSpecd = TestParams() withSeedSpecd = TestParams(seed=42)