Skip to content
Snippets Groups Projects
Commit 603a721c authored by Yanbo Liang's avatar Yanbo Liang Committed by Xiangrui Meng
Browse files

[SPARK-11820][ML][PYSPARK] PySpark LiR & LoR should support weightCol

[SPARK-7685](https://issues.apache.org/jira/browse/SPARK-7685) and [SPARK-9642](https://issues.apache.org/jira/browse/SPARK-9642) have already supported setting weight column for ```LogisticRegression``` and ```LinearRegression```. It's a very important feature, PySpark should also support. mengxr

Author: Yanbo Liang <ybliang8@gmail.com>

Closes #9811 from yanboliang/spark-11820.
parent e222d758
No related branches found
No related tags found
No related merge requests found
...@@ -36,7 +36,8 @@ __all__ = ['LogisticRegression', 'LogisticRegressionModel', 'DecisionTreeClassif ...@@ -36,7 +36,8 @@ __all__ = ['LogisticRegression', 'LogisticRegressionModel', 'DecisionTreeClassif
@inherit_doc @inherit_doc
class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter, class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter,
HasRegParam, HasTol, HasProbabilityCol, HasRawPredictionCol, HasRegParam, HasTol, HasProbabilityCol, HasRawPredictionCol,
HasElasticNetParam, HasFitIntercept, HasStandardization, HasThresholds): HasElasticNetParam, HasFitIntercept, HasStandardization, HasThresholds,
HasWeightCol):
""" """
Logistic regression. Logistic regression.
Currently, this class only supports binary classification. Currently, this class only supports binary classification.
...@@ -44,9 +45,9 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti ...@@ -44,9 +45,9 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti
>>> from pyspark.sql import Row >>> from pyspark.sql import Row
>>> from pyspark.mllib.linalg import Vectors >>> from pyspark.mllib.linalg import Vectors
>>> df = sc.parallelize([ >>> df = sc.parallelize([
... Row(label=1.0, features=Vectors.dense(1.0)), ... Row(label=1.0, weight=2.0, features=Vectors.dense(1.0)),
... Row(label=0.0, features=Vectors.sparse(1, [], []))]).toDF() ... Row(label=0.0, weight=2.0, features=Vectors.sparse(1, [], []))]).toDF()
>>> lr = LogisticRegression(maxIter=5, regParam=0.01) >>> lr = LogisticRegression(maxIter=5, regParam=0.01, weightCol="weight")
>>> model = lr.fit(df) >>> model = lr.fit(df)
>>> model.weights >>> model.weights
DenseVector([5.5...]) DenseVector([5.5...])
...@@ -80,12 +81,12 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti ...@@ -80,12 +81,12 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti
def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True,
threshold=0.5, thresholds=None, probabilityCol="probability", threshold=0.5, thresholds=None, probabilityCol="probability",
rawPredictionCol="rawPrediction", standardization=True): rawPredictionCol="rawPrediction", standardization=True, weightCol=None):
""" """
__init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \ maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \
threshold=0.5, thresholds=None, probabilityCol="probability", \ threshold=0.5, thresholds=None, probabilityCol="probability", \
rawPredictionCol="rawPrediction", standardization=True) rawPredictionCol="rawPrediction", standardization=True, weightCol=None)
If the threshold and thresholds Params are both set, they must be equivalent. If the threshold and thresholds Params are both set, they must be equivalent.
""" """
super(LogisticRegression, self).__init__() super(LogisticRegression, self).__init__()
...@@ -105,12 +106,12 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti ...@@ -105,12 +106,12 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti
def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True,
threshold=0.5, thresholds=None, probabilityCol="probability", threshold=0.5, thresholds=None, probabilityCol="probability",
rawPredictionCol="rawPrediction", standardization=True): rawPredictionCol="rawPrediction", standardization=True, weightCol=None):
""" """
setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \ maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \
threshold=0.5, thresholds=None, probabilityCol="probability", \ threshold=0.5, thresholds=None, probabilityCol="probability", \
rawPredictionCol="rawPrediction", standardization=True) rawPredictionCol="rawPrediction", standardization=True, weightCol=None)
Sets params for logistic regression. Sets params for logistic regression.
If the threshold and thresholds Params are both set, they must be equivalent. If the threshold and thresholds Params are both set, they must be equivalent.
""" """
......
...@@ -35,7 +35,7 @@ __all__ = ['AFTSurvivalRegression', 'AFTSurvivalRegressionModel', ...@@ -35,7 +35,7 @@ __all__ = ['AFTSurvivalRegression', 'AFTSurvivalRegressionModel',
@inherit_doc @inherit_doc
class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter, class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter,
HasRegParam, HasTol, HasElasticNetParam, HasFitIntercept, HasRegParam, HasTol, HasElasticNetParam, HasFitIntercept,
HasStandardization, HasSolver): HasStandardization, HasSolver, HasWeightCol):
""" """
Linear regression. Linear regression.
...@@ -50,9 +50,9 @@ class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPrediction ...@@ -50,9 +50,9 @@ class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPrediction
>>> from pyspark.mllib.linalg import Vectors >>> from pyspark.mllib.linalg import Vectors
>>> df = sqlContext.createDataFrame([ >>> df = sqlContext.createDataFrame([
... (1.0, Vectors.dense(1.0)), ... (1.0, 2.0, Vectors.dense(1.0)),
... (0.0, Vectors.sparse(1, [], []))], ["label", "features"]) ... (0.0, 2.0, Vectors.sparse(1, [], []))], ["label", "weight", "features"])
>>> lr = LinearRegression(maxIter=5, regParam=0.0, solver="normal") >>> lr = LinearRegression(maxIter=5, regParam=0.0, solver="normal", weightCol="weight")
>>> model = lr.fit(df) >>> model = lr.fit(df)
>>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"]) >>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
>>> abs(model.transform(test0).head().prediction - (-1.0)) < 0.001 >>> abs(model.transform(test0).head().prediction - (-1.0)) < 0.001
...@@ -75,11 +75,11 @@ class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPrediction ...@@ -75,11 +75,11 @@ class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPrediction
@keyword_only @keyword_only
def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True,
standardization=True, solver="auto"): standardization=True, solver="auto", weightCol=None):
""" """
__init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \ maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \
standardization=True, solver="auto") standardization=True, solver="auto", weightCol=None)
""" """
super(LinearRegression, self).__init__() super(LinearRegression, self).__init__()
self._java_obj = self._new_java_obj( self._java_obj = self._new_java_obj(
...@@ -92,11 +92,11 @@ class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPrediction ...@@ -92,11 +92,11 @@ class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPrediction
@since("1.4.0") @since("1.4.0")
def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True,
standardization=True, solver="auto"): standardization=True, solver="auto", weightCol=None):
""" """
setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \ maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \
standardization=True, solver="auto") standardization=True, solver="auto", weightCol=None)
Sets params for linear regression. Sets params for linear regression.
""" """
kwargs = self.setParams._input_kwargs kwargs = self.setParams._input_kwargs
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment