Skip to content
Snippets Groups Projects
Commit 7f024c47 authored by Kai Jiang's avatar Kai Jiang Committed by Xiangrui Meng
Browse files

[SPARK-13597][PYSPARK][ML] Python API for GeneralizedLinearRegression

## What changes were proposed in this pull request?

Python API for GeneralizedLinearRegression
JIRA: https://issues.apache.org/jira/browse/SPARK-13597

## How was this patch tested?

The patch is tested with Python doctest.

Author: Kai Jiang <jiangkai@gmail.com>

Closes #11468 from vectorijk/spark-13597.
parent 101663f1
No related branches found
No related tags found
No related merge requests found
...@@ -28,6 +28,7 @@ from pyspark.sql import DataFrame ...@@ -28,6 +28,7 @@ from pyspark.sql import DataFrame
__all__ = ['AFTSurvivalRegression', 'AFTSurvivalRegressionModel', __all__ = ['AFTSurvivalRegression', 'AFTSurvivalRegressionModel',
'DecisionTreeRegressor', 'DecisionTreeRegressionModel', 'DecisionTreeRegressor', 'DecisionTreeRegressionModel',
'GBTRegressor', 'GBTRegressionModel', 'GBTRegressor', 'GBTRegressionModel',
'GeneralizedLinearRegression', 'GeneralizedLinearRegressionModel'
'IsotonicRegression', 'IsotonicRegressionModel', 'IsotonicRegression', 'IsotonicRegressionModel',
'LinearRegression', 'LinearRegressionModel', 'LinearRegression', 'LinearRegressionModel',
'LinearRegressionSummary', 'LinearRegressionTrainingSummary', 'LinearRegressionSummary', 'LinearRegressionTrainingSummary',
...@@ -1197,6 +1198,150 @@ class AFTSurvivalRegressionModel(JavaModel, JavaMLWritable, JavaMLReadable): ...@@ -1197,6 +1198,150 @@ class AFTSurvivalRegressionModel(JavaModel, JavaMLWritable, JavaMLReadable):
return self._call_java("predict", features) return self._call_java("predict", features)
@inherit_doc
class GeneralizedLinearRegression(JavaEstimator, HasLabelCol, HasFeaturesCol, HasPredictionCol,
HasFitIntercept, HasMaxIter, HasTol, HasRegParam, HasWeightCol,
HasSolver, JavaMLWritable, JavaMLReadable):
"""
Generalized Linear Regression.
Fit a Generalized Linear Model specified by giving a symbolic description of the linear
predictor (link function) and a description of the error distribution (family). It supports
"gaussian", "binomial", "poisson" and "gamma" as family. Valid link functions for each family
is listed below. The first link function of each family is the default one.
- "gaussian" -> "identity", "log", "inverse"
- "binomial" -> "logit", "probit", "cloglog"
- "poisson" -> "log", "identity", "sqrt"
- "gamma" -> "inverse", "identity", "log"
.. seealso:: `GLM <https://en.wikipedia.org/wiki/Generalized_linear_model>`_
>>> from pyspark.mllib.linalg import Vectors
>>> df = sqlContext.createDataFrame([
... (1.0, Vectors.dense(0.0, 0.0)),
... (1.0, Vectors.dense(1.0, 2.0)),
... (2.0, Vectors.dense(0.0, 0.0)),
... (2.0, Vectors.dense(1.0, 1.0)),], ["label", "features"])
>>> glr = GeneralizedLinearRegression(family="gaussian", link="identity")
>>> model = glr.fit(df)
>>> abs(model.transform(df).head().prediction - 1.5) < 0.001
True
>>> model.coefficients
DenseVector([1.5..., -1.0...])
>>> abs(model.intercept - 1.5) < 0.001
True
>>> glr_path = temp_path + "/glr"
>>> glr.save(glr_path)
>>> glr2 = GeneralizedLinearRegression.load(glr_path)
>>> glr.getFamily() == glr2.getFamily()
True
>>> model_path = temp_path + "/glr_model"
>>> model.save(model_path)
>>> model2 = GeneralizedLinearRegressionModel.load(model_path)
>>> model.intercept == model2.intercept
True
>>> model.coefficients[0] == model2.coefficients[0]
True
.. versionadded:: 2.0.0
"""
family = Param(Params._dummy(), "family", "The name of family which is a description of " +
"the error distribution to be used in the model. Supported options: " +
"gaussian(default), binomial, poisson and gamma.")
link = Param(Params._dummy(), "link", "The name of link function which provides the " +
"relationship between the linear predictor and the mean of the distribution " +
"function. Supported options: identity, log, inverse, logit, probit, cloglog " +
"and sqrt.")
@keyword_only
def __init__(self, labelCol="label", featuresCol="features", predictionCol="prediction",
family="gaussian", link=None, fitIntercept=True, maxIter=25, tol=1e-6,
regParam=0.0, weightCol=None, solver="irls"):
"""
__init__(self, labelCol="label", featuresCol="features", predictionCol="prediction", \
family="gaussian", link=None, fitIntercept=True, maxIter=25, tol=1e-6, \
regParam=0.0, weightCol=None, solver="irls")
"""
super(GeneralizedLinearRegression, self).__init__()
self._java_obj = self._new_java_obj(
"org.apache.spark.ml.regression.GeneralizedLinearRegression", self.uid)
self._setDefault(family="gaussian", maxIter=25, tol=1e-6, regParam=0.0, solver="irls")
kwargs = self.__init__._input_kwargs
self.setParams(**kwargs)
@keyword_only
@since("2.0.0")
def setParams(self, labelCol="label", featuresCol="features", predictionCol="prediction",
family="gaussian", link=None, fitIntercept=True, maxIter=25, tol=1e-6,
regParam=0.0, weightCol=None, solver="irls"):
"""
setParams(self, labelCol="label", featuresCol="features", predictionCol="prediction", \
family="gaussian", link=None, fitIntercept=True, maxIter=25, tol=1e-6, \
regParam=0.0, weightCol=None, solver="irls")
Sets params for generalized linear regression.
"""
kwargs = self.setParams._input_kwargs
return self._set(**kwargs)
def _create_model(self, java_model):
return GeneralizedLinearRegressionModel(java_model)
@since("2.0.0")
def setFamily(self, value):
"""
Sets the value of :py:attr:`family`.
"""
self._paramMap[self.family] = value
return self
@since("2.0.0")
def getFamily(self):
"""
Gets the value of family or its default value.
"""
return self.getOrDefault(self.family)
@since("2.0.0")
def setLink(self, value):
"""
Sets the value of :py:attr:`link`.
"""
self._paramMap[self.link] = value
return self
@since("2.0.0")
def getLink(self):
"""
Gets the value of link or its default value.
"""
return self.getOrDefault(self.link)
class GeneralizedLinearRegressionModel(JavaModel, JavaMLWritable, JavaMLReadable):
"""
Model fitted by GeneralizedLinearRegression.
.. versionadded:: 2.0.0
"""
@property
@since("2.0.0")
def coefficients(self):
"""
Model coefficients.
"""
return self._call_java("coefficients")
@property
@since("2.0.0")
def intercept(self):
"""
Model intercept.
"""
return self._call_java("intercept")
if __name__ == "__main__": if __name__ == "__main__":
import doctest import doctest
import pyspark.ml.regression import pyspark.ml.regression
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment