Commit 65c696ec authored by Ram Sriharsha, committed by Xiangrui Meng

[SPARK-7833] [ML] Add python wrapper for RegressionEvaluator

Author: Ram Sriharsha <rsriharsha@hw11853.local>

Closes #6365 from harsha2010/SPARK-7833 and squashes the following commits:

923f288 [Ram Sriharsha] cleanup
7623b7d [Ram Sriharsha] python style fix
9743f83 [Ram Sriharsha] [SPARK-7833][ml] Add python wrapper for RegressionEvaluator
parent ed21476b
@@ -31,14 +31,14 @@ import org.apache.spark.sql.types.DoubleType
  * Evaluator for regression, which expects two input columns: prediction and label.
  */
 @AlphaComponent
-class RegressionEvaluator(override val uid: String)
+final class RegressionEvaluator(override val uid: String)
   extends Evaluator with HasPredictionCol with HasLabelCol {

   def this() = this(Identifiable.randomUID("regEval"))

   /**
    * param for metric name in evaluation
-   * @group param
+   * @group param supports mse, rmse, r2, mae as valid metric names.
    */
   val metricName: Param[String] = {
     val allowedParams = ParamValidators.inArray(Array("mse", "rmse", "r2", "mae"))
...
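For reference, the metric names accepted by ParamValidators.inArray above are the standard regression metrics. Below is a minimal, illustrative Python sketch of their definitions; the evaluator itself computes them on the JVM side (via MLlib's RegressionMetrics), so this is a reference sketch only and not part of the patch:

# Illustrative definitions of the metrics selectable through metricName.
# Not part of the patch; actual values come from the JVM-side evaluator.
def regression_metric(name, predictions, labels):
    n = float(len(labels))
    errors = [p - l for p, l in zip(predictions, labels)]
    mse = sum(e * e for e in errors) / n
    if name == "mse":
        return mse
    if name == "rmse":
        return mse ** 0.5
    if name == "mae":
        return sum(abs(e) for e in errors) / n
    if name == "r2":
        mean_label = sum(labels) / n
        ss_tot = sum((l - mean_label) ** 2 for l in labels)
        return 1.0 - (mse * n) / ss_tot  # 1 - SS_res / SS_tot
    raise ValueError("metric name should be one of mse|rmse|r2|mae")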
@@ -39,6 +39,7 @@ class RegressionEvaluatorSuite extends FunSuite with MLlibTestSparkContext
     val dataset = sqlContext.createDataFrame(
       sc.parallelize(LinearDataGenerator.generateLinearInput(
         6.3, Array(4.7, 7.2), Array(0.9, -1.3), Array(0.7, 1.2), 100, 42, 0.1), 2))
+
     /**
      * Using the following R code to load the data, train the model and evaluate metrics.
      *
...
@@ -19,11 +19,11 @@ from abc import abstractmethod, ABCMeta

 from pyspark.ml.wrapper import JavaWrapper
 from pyspark.ml.param import Param, Params
-from pyspark.ml.param.shared import HasLabelCol, HasRawPredictionCol
+from pyspark.ml.param.shared import HasLabelCol, HasPredictionCol, HasRawPredictionCol
 from pyspark.ml.util import keyword_only
 from pyspark.mllib.common import inherit_doc

-__all__ = ['Evaluator', 'BinaryClassificationEvaluator']
+__all__ = ['Evaluator', 'BinaryClassificationEvaluator', 'RegressionEvaluator']


 @inherit_doc
@@ -148,6 +148,70 @@ class BinaryClassificationEvaluator(JavaEvaluator, HasLabelCol, HasRawPrediction
         return self._set(**kwargs)


+@inherit_doc
+class RegressionEvaluator(JavaEvaluator, HasLabelCol, HasPredictionCol):
+    """
+    Evaluator for Regression, which expects two input
+    columns: prediction and label.
+
+    >>> scoreAndLabels = [(-28.98343821, -27.0), (20.21491975, 21.5),
+    ...   (-25.98418959, -22.0), (30.69731842, 33.0), (74.69283752, 71.0)]
+    >>> dataset = sqlContext.createDataFrame(scoreAndLabels, ["raw", "label"])
+    ...
+    >>> evaluator = RegressionEvaluator(predictionCol="raw")
+    >>> evaluator.evaluate(dataset)
+    2.842...
+    >>> evaluator.evaluate(dataset, {evaluator.metricName: "r2"})
+    0.993...
+    >>> evaluator.evaluate(dataset, {evaluator.metricName: "mae"})
+    2.649...
+    """
+    # a placeholder to make it appear in the generated doc
+    metricName = Param(Params._dummy(), "metricName",
+                       "metric name in evaluation (mse|rmse|r2|mae)")
+
+    @keyword_only
+    def __init__(self, predictionCol="prediction", labelCol="label",
+                 metricName="rmse"):
+        """
+        __init__(self, predictionCol="prediction", labelCol="label", \
+                 metricName="rmse")
+        """
+        super(RegressionEvaluator, self).__init__()
+        self._java_obj = self._new_java_obj(
+            "org.apache.spark.ml.evaluation.RegressionEvaluator", self.uid)
+        #: param for metric name in evaluation (mse|rmse|r2|mae)
+        self.metricName = Param(self, "metricName",
+                                "metric name in evaluation (mse|rmse|r2|mae)")
+        self._setDefault(predictionCol="prediction", labelCol="label",
+                         metricName="rmse")
+        kwargs = self.__init__._input_kwargs
+        self._set(**kwargs)
+
+    def setMetricName(self, value):
+        """
+        Sets the value of :py:attr:`metricName`.
+        """
+        self._paramMap[self.metricName] = value
+        return self
+
+    def getMetricName(self):
+        """
+        Gets the value of metricName or its default value.
+        """
+        return self.getOrDefault(self.metricName)
+
+    @keyword_only
+    def setParams(self, predictionCol="prediction", labelCol="label",
+                  metricName="rmse"):
+        """
+        setParams(self, predictionCol="prediction", labelCol="label",
+                  metricName="rmse")
+        Sets params for regression evaluator.
+        """
+        kwargs = self.setParams._input_kwargs
+        return self._set(**kwargs)
+
+
 if __name__ == "__main__":
     import doctest
     from pyspark.context import SparkContext
...
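With the wrapper in place, it plugs into pyspark.ml like the existing BinaryClassificationEvaluator. A minimal usage sketch, assuming a Spark 1.4-era setup with an active SQLContext and hypothetical trainDF/testDF DataFrames that contain "features" and "label" columns:

# Sketch: evaluating a fitted pyspark.ml model with the new RegressionEvaluator.
# trainDF and testDF are hypothetical DataFrames with "features"/"label" columns.
from pyspark.ml.regression import LinearRegression
from pyspark.ml.evaluation import RegressionEvaluator

lr = LinearRegression(featuresCol="features", labelCol="label")
model = lr.fit(trainDF)
predictions = model.transform(testDF)   # adds a "prediction" column

evaluator = RegressionEvaluator()       # defaults: predictionCol="prediction", metricName="rmse"
rmse = evaluator.evaluate(predictions)
r2 = evaluator.evaluate(predictions, {evaluator.metricName: "r2"})

The same evaluator instance can also be passed anywhere pyspark.ml expects an Evaluator.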