Repository: spark
Updated Branches:
  refs/heads/master 45d40d9f6 -> 39f328ba3


[SPARK-15018][PYSPARK][ML] Improve handling of PySpark Pipeline when used without stages

## What changes were proposed in this pull request?

When fitting a PySpark Pipeline without the `stages` param set, a confusing
NoneType error is raised when the pipeline attempts to iterate over its stages.
A pipeline with no stages should act as an identity transform; however, the
`stages` param still needs to be set to an empty list. This change improves
the error output when the `stages` param is not set and adds a better
description of what the API expects as input. It also includes minor cleanup
of related code.
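The difference between the old and new `getStages` lookups can be illustrated with a minimal, self-contained sketch. `MiniPipeline`, `get_stages_old`, `get_stages_new` and `fit_transform` are hypothetical names, not PySpark API; the class only mimics the `_paramMap`/`_defaultParamMap` lookup that the diff below changes, assuming (as with `@keyword_only`) that only keyword arguments the caller actually passed are recorded. That assumption is also why the removed `if stages is None: stages = []` lines had no effect: reassigning the local variable never changed the captured kwargs.

```python
class MiniPipeline(object):
    """Hypothetical stand-in for the param handling in pyspark.ml.Pipeline.

    Not the real PySpark API: it only mimics the explicit-value map
    (_paramMap) and the default-value map (_defaultParamMap).
    """

    def __init__(self, **kwargs):
        # Like @keyword_only, record only keyword arguments the caller
        # actually passed; MiniPipeline() therefore leaves "stages" unset.
        self._paramMap = dict(kwargs)
        self._defaultParamMap = {}  # "stages" has no default

    def get_stages_old(self):
        # Pre-patch style lookup: falls through and implicitly returns
        # None when "stages" was never set.
        if "stages" in self._paramMap:
            return self._paramMap["stages"]

    def get_stages_new(self):
        # Post-patch style lookup, modelled on Params.getOrDefault: explicit
        # value first, then default, else the dict lookup raises KeyError.
        if "stages" in self._paramMap:
            return self._paramMap["stages"]
        return self._defaultParamMap["stages"]

    def fit_transform(self, dataset, use_new=True):
        # Apply each stage in order; with stages=[] the loop body never
        # runs, so the dataset is returned unchanged (identity).
        stages = self.get_stages_new() if use_new else self.get_stages_old()
        for stage in stages:
            dataset = stage(dataset)
        return dataset


if __name__ == "__main__":
    print(MiniPipeline(stages=[]).fit_transform([1, 2, 3]))  # prints [1, 2, 3]
    for use_new in (False, True):
        try:
            MiniPipeline().fit_transform([1, 2, 3], use_new=use_new)
        except (TypeError, KeyError) as exc:
            # old lookup: TypeError ('NoneType' object is not iterable)
            # new lookup: KeyError for the missing "stages" param
            print(type(exc).__name__)
```

Here a "stage" is just a callable on a list, whereas real pipeline stages are Estimators or Transformers; the point is only that an unset param now fails with a `KeyError` naming the param instead of a `NoneType` iteration error.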

## How was this patch tested?
Added new unit tests to verify that an empty Pipeline acts as an identity transformer.

Author: Bryan Cutler <cutl...@gmail.com>

Closes #12790 from BryanCutler/pipeline-identity-SPARK-15018.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/39f328ba
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/39f328ba
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/39f328ba

Branch: refs/heads/master
Commit: 39f328ba3519b01940a7d1cdee851ba4e75ef31f
Parents: 45d40d9
Author: Bryan Cutler <cutl...@gmail.com>
Authored: Fri Aug 19 23:46:36 2016 -0700
Committer: Yanbo Liang <yblia...@gmail.com>
Committed: Fri Aug 19 23:46:36 2016 -0700

----------------------------------------------------------------------
 python/pyspark/ml/pipeline.py | 11 +++--------
 python/pyspark/ml/tests.py    | 11 +++++++++++
 2 files changed, 14 insertions(+), 8 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/39f328ba/python/pyspark/ml/pipeline.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/pipeline.py b/python/pyspark/ml/pipeline.py
index a48f4bb..4307ad0 100644
--- a/python/pyspark/ml/pipeline.py
+++ b/python/pyspark/ml/pipeline.py
@@ -44,21 +44,19 @@ class Pipeline(Estimator, MLReadable, MLWritable):
     the dataset for the next stage. The fitted model from a
     :py:class:`Pipeline` is a :py:class:`PipelineModel`, which
     consists of fitted models and transformers, corresponding to the
-    pipeline stages. If there are no stages, the pipeline acts as an
+    pipeline stages. If stages is an empty list, the pipeline acts as an
     identity transformer.
 
     .. versionadded:: 1.3.0
     """
 
-    stages = Param(Params._dummy(), "stages", "pipeline stages")
+    stages = Param(Params._dummy(), "stages", "a list of pipeline stages")
 
     @keyword_only
     def __init__(self, stages=None):
         """
         __init__(self, stages=None)
         """
-        if stages is None:
-            stages = []
         super(Pipeline, self).__init__()
         kwargs = self.__init__._input_kwargs
         self.setParams(**kwargs)
@@ -78,8 +76,7 @@ class Pipeline(Estimator, MLReadable, MLWritable):
         """
         Get pipeline stages.
         """
-        if self.stages in self._paramMap:
-            return self._paramMap[self.stages]
+        return self.getOrDefault(self.stages)
 
     @keyword_only
     @since("1.3.0")
@@ -88,8 +85,6 @@ class Pipeline(Estimator, MLReadable, MLWritable):
         setParams(self, stages=None)
         Sets params for Pipeline.
         """
-        if stages is None:
-            stages = []
         kwargs = self.setParams._input_kwargs
         return self._set(**kwargs)
 

http://git-wip-us.apache.org/repos/asf/spark/blob/39f328ba/python/pyspark/ml/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index 4bcb2c4..6886ed3 100755
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -230,6 +230,17 @@ class PipelineTests(PySparkTestCase):
         self.assertEqual(5, transformer3.dataset_index)
         self.assertEqual(6, dataset.index)
 
+    def test_identity_pipeline(self):
+        dataset = MockDataset()
+
+        def doTransform(pipeline):
+            pipeline_model = pipeline.fit(dataset)
+            return pipeline_model.transform(dataset)
+        # check that empty pipeline did not perform any transformation
+        self.assertEqual(dataset.index, doTransform(Pipeline(stages=[])).index)
+        # check that failure to set stages param will raise KeyError for missing param
+        self.assertRaises(KeyError, lambda: doTransform(Pipeline()))
+
 
 class TestParams(HasMaxIter, HasInputCol, HasSeed):
     """

