Commit 39f328ba authored by Bryan Cutler, committed by Yanbo Liang

[SPARK-15018][PYSPARK][ML] Improve handling of PySpark Pipeline when used without stages

## What changes were proposed in this pull request?

When fitting a PySpark Pipeline without the `stages` param set, a confusing NoneType error is raised as `fit` attempts to iterate over the pipeline stages. A pipeline with no stages should act as an identity transform; however, the `stages` param still needs to be set explicitly to an empty list. This change improves the error output when the `stages` param is not set and adds a better description of what the API expects as input. It also includes minor cleanup of related code.
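
The difference between the old and new lookups can be sketched without Spark. This is a minimal illustration, not the PySpark implementation: the plain dicts stand in for the param maps, and `get_stages_old`, `get_stages_new` and `describe_failure` are hypothetical names. It assumes the new lookup behaves like `getOrDefault`, which the added test expects to raise `KeyError` when the param is unset and has no default.

```python
def get_stages_old(param_map):
    """Pre-change lookup: returns the value only if the param was set."""
    if "stages" in param_map:
        return param_map["stages"]
    # Falls through here and implicitly returns None.


def get_stages_new(param_map, default_param_map):
    """Post-change lookup, modelled on getOrDefault (an assumption)."""
    if "stages" in param_map:
        return param_map["stages"]
    # Raises KeyError when the param is neither set nor defaulted.
    return default_param_map["stages"]


def describe_failure(lookup):
    """Iterate over the looked-up stages, as fit() would, and report the outcome."""
    try:
        for _stage in lookup():
            pass
    except (TypeError, KeyError) as exc:
        return type(exc).__name__
    return "ok"


unset = {}  # stand-in for a Pipeline whose `stages` param was never set

old_result = describe_failure(lambda: get_stages_old(unset))
new_result = describe_failure(lambda: get_stages_new(unset, {}))
empty_result = describe_failure(lambda: get_stages_new({"stages": []}, {}))

print(old_result, new_result, empty_result)
```

In this sketch the old lookup fails late, with a `TypeError` from iterating over `None`, while the new lookup fails at the lookup itself with a `KeyError` that names the missing param. An explicitly empty list iterates without error.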

## How was this patch tested?
Added new unit tests to verify that an empty Pipeline acts as an identity transformer.
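
The behaviour being tested can be sketched with simplified stand-ins. `MiniPipeline` and `MiniPipelineModel` below are hypothetical classes, not the PySpark ones; stages are assumed to be plain callables, and every stage is treated as an already fitted transformer.

```python
class MiniPipelineModel:
    """Hypothetical stand-in for a fitted pipeline: applies stages in order."""

    def __init__(self, stages):
        self.stages = stages

    def transform(self, dataset):
        for stage in self.stages:
            dataset = stage(dataset)
        return dataset


class MiniPipeline:
    """Hypothetical, simplified stand-in for a pipeline estimator."""

    def __init__(self, stages=None):
        self._param_map = {}
        # Only an explicitly passed value is stored; None means "not set".
        if stages is not None:
            self._param_map["stages"] = stages

    def get_stages(self):
        # Raises KeyError if `stages` was never set (assumed behaviour).
        return self._param_map["stages"]

    def fit(self, dataset):
        # Simplification: no stage is an estimator, so nothing is fitted.
        return MiniPipelineModel(list(self.get_stages()))


data = [1, 2, 3]

# An explicitly empty pipeline leaves the dataset untouched.
identity_out = MiniPipeline(stages=[]).fit(data).transform(data)

# A single doubling stage, for contrast.
doubled_out = MiniPipeline(stages=[lambda d: [x * 2 for x in d]]).fit(data).transform(data)

# Leaving `stages` unset fails at fit time with a KeyError.
try:
    MiniPipeline().fit(data)
    unset_error = None
except KeyError as exc:
    unset_error = exc
```

The two cases mirror the assertions in the added test: `stages=[]` yields an identity transform, and an unset `stages` raises `KeyError`.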

Author: Bryan Cutler <cutlerb@gmail.com>

Closes #12790 from BryanCutler/pipeline-identity-SPARK-15018.
parent 45d40d9f
@@ -44,21 +44,19 @@ class Pipeline(Estimator, MLReadable, MLWritable):
     the dataset for the next stage. The fitted model from a
     :py:class:`Pipeline` is a :py:class:`PipelineModel`, which
     consists of fitted models and transformers, corresponding to the
-    pipeline stages. If there are no stages, the pipeline acts as an
+    pipeline stages. If stages is an empty list, the pipeline acts as an
     identity transformer.
 
     .. versionadded:: 1.3.0
     """
 
-    stages = Param(Params._dummy(), "stages", "pipeline stages")
+    stages = Param(Params._dummy(), "stages", "a list of pipeline stages")
 
     @keyword_only
     def __init__(self, stages=None):
         """
         __init__(self, stages=None)
         """
-        if stages is None:
-            stages = []
         super(Pipeline, self).__init__()
         kwargs = self.__init__._input_kwargs
         self.setParams(**kwargs)
@@ -78,8 +76,7 @@ class Pipeline(Estimator, MLReadable, MLWritable):
         """
         Get pipeline stages.
         """
-        if self.stages in self._paramMap:
-            return self._paramMap[self.stages]
+        return self.getOrDefault(self.stages)
 
     @keyword_only
     @since("1.3.0")
@@ -88,8 +85,6 @@ class Pipeline(Estimator, MLReadable, MLWritable):
         setParams(self, stages=None)
         Sets params for Pipeline.
         """
-        if stages is None:
-            stages = []
         kwargs = self.setParams._input_kwargs
         return self._set(**kwargs)
@@ -230,6 +230,17 @@ class PipelineTests(PySparkTestCase):
         self.assertEqual(5, transformer3.dataset_index)
         self.assertEqual(6, dataset.index)
 
+    def test_identity_pipeline(self):
+        dataset = MockDataset()
+
+        def doTransform(pipeline):
+            pipeline_model = pipeline.fit(dataset)
+            return pipeline_model.transform(dataset)
+        # check that empty pipeline did not perform any transformation
+        self.assertEqual(dataset.index, doTransform(Pipeline(stages=[])).index)
+        # check that failure to set stages param will raise KeyError for missing param
+        self.assertRaises(KeyError, lambda: doTransform(Pipeline()))
+
 
 class TestParams(HasMaxIter, HasInputCol, HasSeed):
     """