Skip to content
Snippets Groups Projects
Commit 63b200e8 authored by wm624@hotmail.com's avatar wm624@hotmail.com Committed by Joseph K. Bradley
Browse files

[SPARK-14071][PYSPARK][ML] Change MLWritable.write to be a property

Add property to MLWritable.write method, so we can use .write instead of .write()

Add a new test to ml/test.py to check whether the write is a property.
./python/run-tests --python-executables=python2.7 --modules=pyspark-ml

Will test against the following Python executables: ['python2.7']
Will test the following Python modules: ['pyspark-ml']
Finished test(python2.7): pyspark.ml.evaluation (11s)
Finished test(python2.7): pyspark.ml.clustering (16s)
Finished test(python2.7): pyspark.ml.classification (24s)
Finished test(python2.7): pyspark.ml.recommendation (24s)
Finished test(python2.7): pyspark.ml.feature (39s)
Finished test(python2.7): pyspark.ml.regression (26s)
Finished test(python2.7): pyspark.ml.tuning (15s)
Finished test(python2.7): pyspark.ml.tests (30s)
Tests passed in 55 seconds

Author: wm624@hotmail.com <wm624@hotmail.com>

Closes #11945 from wangmiao1981/fix_property.
parent f6066b0c
No related branches found
No related tags found
No related merge requests found
......@@ -51,6 +51,7 @@ from pyspark.ml.param.shared import HasMaxIter, HasInputCol, HasSeed
from pyspark.ml.regression import LinearRegression, DecisionTreeRegressor
from pyspark.ml.tuning import *
from pyspark.ml.util import keyword_only
from pyspark.ml.util import MLWritable, MLWriter
from pyspark.ml.wrapper import JavaWrapper
from pyspark.mllib.linalg import DenseVector, SparseVector
from pyspark.sql import DataFrame, SQLContext, Row
......@@ -655,6 +656,10 @@ class PersistenceTest(PySparkTestCase):
except OSError:
pass
def test_write_property(self):
lr = LinearRegression(maxIter=1)
self.assertTrue(isinstance(lr.write, MLWriter))
def test_decisiontree_classifier(self):
dt = DecisionTreeClassifier(maxDepth=1)
path = tempfile.mkdtemp()
......
......@@ -134,13 +134,14 @@ class MLWritable(object):
.. versionadded:: 2.0.0
"""
@property
def write(self):
"""Returns an JavaMLWriter instance for this ML instance."""
raise NotImplementedError("MLWritable is not yet implemented for type: %r" % type(self))
def save(self, path):
"""Save this ML instance to the given path, a shortcut of `write().save(path)`."""
self.write().save(path)
self.write.save(path)
@inherit_doc
......@@ -149,6 +150,7 @@ class JavaMLWritable(MLWritable):
(Private) Mixin for ML instances that provide :py:class:`JavaMLWriter`.
"""
@property
def write(self):
"""Returns an JavaMLWriter instance for this ML instance."""
return JavaMLWriter(self)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment