From 69786ea3a972af1b29a332dc11ac507ed4368cc6 Mon Sep 17 00:00:00 2001
From: zero323 <zero323@users.noreply.github.com>
Date: Wed, 10 May 2017 16:57:52 +0800
Subject: [PATCH] [SPARK-20631][PYTHON][ML] LogisticRegression._checkThresholdConsistency should use values not Params

## What changes were proposed in this pull request?

- Replace `getParam` calls with `getOrDefault` calls.
- Fix exception message to avoid unintended `TypeError`.
- Add unit tests

## How was this patch tested?

New unit tests.

Author: zero323 <zero323@users.noreply.github.com>

Closes #17891 from zero323/SPARK-20631.

(cherry picked from commit 804949c6bf00b8e26c39d48bbcc4d0470ee84e47)
Signed-off-by: Yanbo Liang <ybliang8@gmail.com>
---
Reviewer note (not part of the upstream commit message): the test added to
tests.py is named test_logistic_regression_check_thresholds here, because
unittest's default loader only collects methods whose names start with
"test"; without the prefix the new test would never run. Its isinstance
check targets LogisticRegression, since calling the class yields the
estimator itself rather than a LogisticRegressionModel. The post-image blob
id on the tests.py "index" line predates this adjustment. Leading whitespace
on context lines was restored from a whitespace-collapsed copy of the patch;
verify it against the target tree before applying.

 python/pyspark/ml/classification.py |  6 +++---
 python/pyspark/ml/tests.py          | 12 ++++++++++++
 2 files changed, 15 insertions(+), 3 deletions(-)

diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py
index 570a414cc3..2b47c40267 100644
--- a/python/pyspark/ml/classification.py
+++ b/python/pyspark/ml/classification.py
@@ -238,13 +238,13 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti
 
     def _checkThresholdConsistency(self):
         if self.isSet(self.threshold) and self.isSet(self.thresholds):
-            ts = self.getParam(self.thresholds)
+            ts = self.getOrDefault(self.thresholds)
             if len(ts) != 2:
                 raise ValueError("Logistic Regression getThreshold only applies to" +
                                  " binary classification, but thresholds has length != 2." +
-                                 " thresholds: " + ",".join(ts))
+                                 " thresholds: {0}".format(str(ts)))
             t = 1.0/(1.0 + ts[0]/ts[1])
-            t2 = self.getParam(self.threshold)
+            t2 = self.getOrDefault(self.threshold)
             if abs(t2 - t) >= 1E-5:
                 raise ValueError("Logistic Regression getThreshold found inconsistent values for" +
                                  " threshold (%g) and thresholds (equivalent to %g)" % (t2, t))
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index 70e0c6de4a..7152036e38 100755
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -808,6 +808,18 @@ class PersistenceTest(SparkSessionTestCase):
         except OSError:
             pass
 
+    def test_logistic_regression_check_thresholds(self):
+        self.assertIsInstance(
+            LogisticRegression(threshold=0.5, thresholds=[0.5, 0.5]),
+            LogisticRegression
+        )
+
+        self.assertRaisesRegexp(
+            ValueError,
+            "Logistic Regression getThreshold found inconsistent.*$",
+            LogisticRegression, threshold=0.42, thresholds=[0.5, 0.5]
+        )
+
     def _compare_params(self, m1, m2, param):
         """
         Compare 2 ML Params instances for the given param, and assert both have the same param value
-- 
GitLab