Skip to content
Snippets Groups Projects
Commit 9ca79c1e authored by Holden Karau, committed by Sean Owen
Browse files

[SPARK-13302][PYSPARK][TESTS] Move the temp file creation and cleanup outside of the doctests

Some of the new doctests in ml/clustering.py have a lot of setup code. Move the setup code to the general test init so that the doctests read more like examples.
In part, this is a follow-up to https://github.com/apache/spark/pull/10999
Note that the same pattern is followed in regression and recommendation, so we might as well clean up all three at the same time.

Author: Holden Karau <holden@us.ibm.com>

Closes #11197 from holdenk/SPARK-13302-cleanup-doctests-in-ml-clustering.
parent dfb2ae2f
No related branches found
No related tags found
No related merge requests found
...@@ -70,25 +70,18 @@ class KMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIter, HasTol ...@@ -70,25 +70,18 @@ class KMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIter, HasTol
True True
>>> rows[2].prediction == rows[3].prediction >>> rows[2].prediction == rows[3].prediction
True True
>>> import os, tempfile >>> kmeans_path = temp_path + "/kmeans"
>>> path = tempfile.mkdtemp()
>>> kmeans_path = path + "/kmeans"
>>> kmeans.save(kmeans_path) >>> kmeans.save(kmeans_path)
>>> kmeans2 = KMeans.load(kmeans_path) >>> kmeans2 = KMeans.load(kmeans_path)
>>> kmeans2.getK() >>> kmeans2.getK()
2 2
>>> model_path = path + "/kmeans_model" >>> model_path = temp_path + "/kmeans_model"
>>> model.save(model_path) >>> model.save(model_path)
>>> model2 = KMeansModel.load(model_path) >>> model2 = KMeansModel.load(model_path)
>>> model.clusterCenters()[0] == model2.clusterCenters()[0] >>> model.clusterCenters()[0] == model2.clusterCenters()[0]
array([ True, True], dtype=bool) array([ True, True], dtype=bool)
>>> model.clusterCenters()[1] == model2.clusterCenters()[1] >>> model.clusterCenters()[1] == model2.clusterCenters()[1]
array([ True, True], dtype=bool) array([ True, True], dtype=bool)
>>> from shutil import rmtree
>>> try:
... rmtree(path)
... except OSError:
... pass
.. versionadded:: 1.5.0 .. versionadded:: 1.5.0
""" """
...@@ -310,7 +303,17 @@ if __name__ == "__main__": ...@@ -310,7 +303,17 @@ if __name__ == "__main__":
sqlContext = SQLContext(sc) sqlContext = SQLContext(sc)
globs['sc'] = sc globs['sc'] = sc
globs['sqlContext'] = sqlContext globs['sqlContext'] = sqlContext
(failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) import tempfile
sc.stop() temp_path = tempfile.mkdtemp()
globs['temp_path'] = temp_path
try:
(failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
sc.stop()
finally:
from shutil import rmtree
try:
rmtree(temp_path)
except OSError:
pass
if failure_count: if failure_count:
exit(-1) exit(-1)
...@@ -82,14 +82,12 @@ class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, Ha ...@@ -82,14 +82,12 @@ class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, Ha
Row(user=1, item=0, prediction=2.6258413791656494) Row(user=1, item=0, prediction=2.6258413791656494)
>>> predictions[2] >>> predictions[2]
Row(user=2, item=0, prediction=-1.5018409490585327) Row(user=2, item=0, prediction=-1.5018409490585327)
>>> import os, tempfile >>> als_path = temp_path + "/als"
>>> path = tempfile.mkdtemp()
>>> als_path = path + "/als"
>>> als.save(als_path) >>> als.save(als_path)
>>> als2 = ALS.load(als_path) >>> als2 = ALS.load(als_path)
>>> als.getMaxIter() >>> als.getMaxIter()
5 5
>>> model_path = path + "/als_model" >>> model_path = temp_path + "/als_model"
>>> model.save(model_path) >>> model.save(model_path)
>>> model2 = ALSModel.load(model_path) >>> model2 = ALSModel.load(model_path)
>>> model.rank == model2.rank >>> model.rank == model2.rank
...@@ -98,11 +96,6 @@ class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, Ha ...@@ -98,11 +96,6 @@ class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, Ha
True True
>>> sorted(model.itemFactors.collect()) == sorted(model2.itemFactors.collect()) >>> sorted(model.itemFactors.collect()) == sorted(model2.itemFactors.collect())
True True
>>> from shutil import rmtree
>>> try:
... rmtree(path)
... except OSError:
... pass
.. versionadded:: 1.4.0 .. versionadded:: 1.4.0
""" """
...@@ -340,7 +333,17 @@ if __name__ == "__main__": ...@@ -340,7 +333,17 @@ if __name__ == "__main__":
sqlContext = SQLContext(sc) sqlContext = SQLContext(sc)
globs['sc'] = sc globs['sc'] = sc
globs['sqlContext'] = sqlContext globs['sqlContext'] = sqlContext
(failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) import tempfile
sc.stop() temp_path = tempfile.mkdtemp()
globs['temp_path'] = temp_path
try:
(failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
sc.stop()
finally:
from shutil import rmtree
try:
rmtree(temp_path)
except OSError:
pass
if failure_count: if failure_count:
exit(-1) exit(-1)
...@@ -68,25 +68,18 @@ class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPrediction ...@@ -68,25 +68,18 @@ class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPrediction
Traceback (most recent call last): Traceback (most recent call last):
... ...
TypeError: Method setParams forces keyword arguments. TypeError: Method setParams forces keyword arguments.
>>> import os, tempfile >>> lr_path = temp_path + "/lr"
>>> path = tempfile.mkdtemp()
>>> lr_path = path + "/lr"
>>> lr.save(lr_path) >>> lr.save(lr_path)
>>> lr2 = LinearRegression.load(lr_path) >>> lr2 = LinearRegression.load(lr_path)
>>> lr2.getMaxIter() >>> lr2.getMaxIter()
5 5
>>> model_path = path + "/lr_model" >>> model_path = temp_path + "/lr_model"
>>> model.save(model_path) >>> model.save(model_path)
>>> model2 = LinearRegressionModel.load(model_path) >>> model2 = LinearRegressionModel.load(model_path)
>>> model.coefficients[0] == model2.coefficients[0] >>> model.coefficients[0] == model2.coefficients[0]
True True
>>> model.intercept == model2.intercept >>> model.intercept == model2.intercept
True True
>>> from shutil import rmtree
>>> try:
... rmtree(path)
... except OSError:
... pass
.. versionadded:: 1.4.0 .. versionadded:: 1.4.0
""" """
...@@ -850,7 +843,17 @@ if __name__ == "__main__": ...@@ -850,7 +843,17 @@ if __name__ == "__main__":
sqlContext = SQLContext(sc) sqlContext = SQLContext(sc)
globs['sc'] = sc globs['sc'] = sc
globs['sqlContext'] = sqlContext globs['sqlContext'] = sqlContext
(failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) import tempfile
sc.stop() temp_path = tempfile.mkdtemp()
globs['temp_path'] = temp_path
try:
(failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
sc.stop()
finally:
from shutil import rmtree
try:
rmtree(temp_path)
except OSError:
pass
if failure_count: if failure_count:
exit(-1) exit(-1)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment