Skip to content
Snippets Groups Projects
Commit 723853ed authored by Xiangrui Meng's avatar Xiangrui Meng
Browse files

[SPARK-7648] [MLLIB] Add weights and intercept to GLM wrappers in spark.ml

Otherwise, users can only use `transform` on the models. brkyvz

Author: Xiangrui Meng <meng@databricks.com>

Closes #6156 from mengxr/SPARK-7647 and squashes the following commits:

1ae3d2d [Xiangrui Meng] add weights and intercept to LogisticRegression in Python
f49eb46 [Xiangrui Meng] add weights and intercept to LinearRegressionModel
parent b208f998
No related branches found
No related tags found
No related merge requests found
...@@ -43,6 +43,10 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti ...@@ -43,6 +43,10 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti
>>> test0 = sc.parallelize([Row(features=Vectors.dense(-1.0))]).toDF() >>> test0 = sc.parallelize([Row(features=Vectors.dense(-1.0))]).toDF()
>>> model.transform(test0).head().prediction >>> model.transform(test0).head().prediction
0.0 0.0
>>> model.weights
DenseVector([5.5...])
>>> model.intercept
-2.68...
>>> test1 = sc.parallelize([Row(features=Vectors.sparse(1, [0], [1.0]))]).toDF() >>> test1 = sc.parallelize([Row(features=Vectors.sparse(1, [0], [1.0]))]).toDF()
>>> model.transform(test1).head().prediction >>> model.transform(test1).head().prediction
1.0 1.0
...@@ -148,6 +152,20 @@ class LogisticRegressionModel(JavaModel): ...@@ -148,6 +152,20 @@ class LogisticRegressionModel(JavaModel):
Model fitted by LogisticRegression. Model fitted by LogisticRegression.
""" """
@property
def weights(self):
"""
Model weights.
"""
return self._call_java("weights")
@property
def intercept(self):
"""
Model intercept.
"""
return self._call_java("intercept")
class TreeClassifierParams(object): class TreeClassifierParams(object):
""" """
......
...@@ -51,6 +51,10 @@ class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPrediction ...@@ -51,6 +51,10 @@ class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPrediction
>>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"]) >>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
>>> model.transform(test0).head().prediction >>> model.transform(test0).head().prediction
-1.0 -1.0
>>> model.weights
DenseVector([1.0])
>>> model.intercept
0.0
>>> test1 = sqlContext.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"]) >>> test1 = sqlContext.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"])
>>> model.transform(test1).head().prediction >>> model.transform(test1).head().prediction
1.0 1.0
...@@ -117,6 +121,20 @@ class LinearRegressionModel(JavaModel): ...@@ -117,6 +121,20 @@ class LinearRegressionModel(JavaModel):
Model fitted by LinearRegression. Model fitted by LinearRegression.
""" """
@property
def weights(self):
"""
Model weights.
"""
return self._call_java("weights")
@property
def intercept(self):
"""
Model intercept.
"""
return self._call_java("intercept")
class TreeRegressorParams(object): class TreeRegressorParams(object):
""" """
......
...@@ -21,7 +21,7 @@ from pyspark import SparkContext ...@@ -21,7 +21,7 @@ from pyspark import SparkContext
from pyspark.sql import DataFrame from pyspark.sql import DataFrame
from pyspark.ml.param import Params from pyspark.ml.param import Params
from pyspark.ml.pipeline import Estimator, Transformer, Evaluator, Model from pyspark.ml.pipeline import Estimator, Transformer, Evaluator, Model
from pyspark.mllib.common import inherit_doc from pyspark.mllib.common import inherit_doc, _java2py, _py2java
def _jvm(): def _jvm():
...@@ -149,6 +149,12 @@ class JavaModel(Model, JavaTransformer): ...@@ -149,6 +149,12 @@ class JavaModel(Model, JavaTransformer):
def _java_obj(self): def _java_obj(self):
return self._java_model return self._java_model
def _call_java(self, name, *args):
m = getattr(self._java_model, name)
sc = SparkContext._active_spark_context
java_args = [_py2java(sc, arg) for arg in args]
return _java2py(sc, m(*java_args))
@inherit_doc @inherit_doc
class JavaEvaluator(Evaluator, JavaWrapper): class JavaEvaluator(Evaluator, JavaWrapper):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment