Skip to content
Snippets Groups Projects
Commit 603a721c authored by Yanbo Liang, committed by Xiangrui Meng
Browse files

[SPARK-11820][ML][PYSPARK] PySpark LiR & LoR should support weightCol

[SPARK-7685](https://issues.apache.org/jira/browse/SPARK-7685) and [SPARK-9642](https://issues.apache.org/jira/browse/SPARK-9642) already added support for setting a weight column in `LogisticRegression` and `LinearRegression`. This is a very important feature that PySpark should support as well. cc mengxr

Author: Yanbo Liang <ybliang8@gmail.com>

Closes #9811 from yanboliang/spark-11820.
parent e222d758
No related branches found
No related tags found
No related merge requests found
......@@ -36,7 +36,8 @@ __all__ = ['LogisticRegression', 'LogisticRegressionModel', 'DecisionTreeClassif
@inherit_doc
class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter,
HasRegParam, HasTol, HasProbabilityCol, HasRawPredictionCol,
HasElasticNetParam, HasFitIntercept, HasStandardization, HasThresholds):
HasElasticNetParam, HasFitIntercept, HasStandardization, HasThresholds,
HasWeightCol):
"""
Logistic regression.
Currently, this class only supports binary classification.
......@@ -44,9 +45,9 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti
>>> from pyspark.sql import Row
>>> from pyspark.mllib.linalg import Vectors
>>> df = sc.parallelize([
... Row(label=1.0, features=Vectors.dense(1.0)),
... Row(label=0.0, features=Vectors.sparse(1, [], []))]).toDF()
>>> lr = LogisticRegression(maxIter=5, regParam=0.01)
... Row(label=1.0, weight=2.0, features=Vectors.dense(1.0)),
... Row(label=0.0, weight=2.0, features=Vectors.sparse(1, [], []))]).toDF()
>>> lr = LogisticRegression(maxIter=5, regParam=0.01, weightCol="weight")
>>> model = lr.fit(df)
>>> model.weights
DenseVector([5.5...])
......@@ -80,12 +81,12 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti
def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True,
threshold=0.5, thresholds=None, probabilityCol="probability",
rawPredictionCol="rawPrediction", standardization=True):
rawPredictionCol="rawPrediction", standardization=True, weightCol=None):
"""
__init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \
threshold=0.5, thresholds=None, probabilityCol="probability", \
rawPredictionCol="rawPrediction", standardization=True)
rawPredictionCol="rawPrediction", standardization=True, weightCol=None)
If the threshold and thresholds Params are both set, they must be equivalent.
"""
super(LogisticRegression, self).__init__()
......@@ -105,12 +106,12 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti
def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True,
threshold=0.5, thresholds=None, probabilityCol="probability",
rawPredictionCol="rawPrediction", standardization=True):
rawPredictionCol="rawPrediction", standardization=True, weightCol=None):
"""
setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
maxIter=100, regParam=0.1, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \
threshold=0.5, thresholds=None, probabilityCol="probability", \
rawPredictionCol="rawPrediction", standardization=True)
rawPredictionCol="rawPrediction", standardization=True, weightCol=None)
Sets params for logistic regression.
If the threshold and thresholds Params are both set, they must be equivalent.
"""
......
......@@ -35,7 +35,7 @@ __all__ = ['AFTSurvivalRegression', 'AFTSurvivalRegressionModel',
@inherit_doc
class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter,
HasRegParam, HasTol, HasElasticNetParam, HasFitIntercept,
HasStandardization, HasSolver):
HasStandardization, HasSolver, HasWeightCol):
"""
Linear regression.
......@@ -50,9 +50,9 @@ class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPrediction
>>> from pyspark.mllib.linalg import Vectors
>>> df = sqlContext.createDataFrame([
... (1.0, Vectors.dense(1.0)),
... (0.0, Vectors.sparse(1, [], []))], ["label", "features"])
>>> lr = LinearRegression(maxIter=5, regParam=0.0, solver="normal")
... (1.0, 2.0, Vectors.dense(1.0)),
... (0.0, 2.0, Vectors.sparse(1, [], []))], ["label", "weight", "features"])
>>> lr = LinearRegression(maxIter=5, regParam=0.0, solver="normal", weightCol="weight")
>>> model = lr.fit(df)
>>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
>>> abs(model.transform(test0).head().prediction - (-1.0)) < 0.001
......@@ -75,11 +75,11 @@ class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPrediction
@keyword_only
def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True,
standardization=True, solver="auto"):
standardization=True, solver="auto", weightCol=None):
"""
__init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \
standardization=True, solver="auto")
standardization=True, solver="auto", weightCol=None)
"""
super(LinearRegression, self).__init__()
self._java_obj = self._new_java_obj(
......@@ -92,11 +92,11 @@ class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPrediction
@since("1.4.0")
def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True,
standardization=True, solver="auto"):
standardization=True, solver="auto", weightCol=None):
"""
setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \
standardization=True, solver="auto")
standardization=True, solver="auto", weightCol=None)
Sets params for linear regression.
"""
kwargs = self.setParams._input_kwargs
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment