Commit 8f11c611 authored by Xiangrui Meng

[SPARK-7535] [.0] [MLLIB] Audit the pipeline APIs for 1.4

Some changes to the pipeline APIs:

1. Estimator/Transformer doesn't need to extend Params since PipelineStage already does.
2. Move Evaluator to ml.evaluation.
3. Mention that larger metric values are better (see the usage sketch after this list).
4. PipelineModel doc: “compiled” -> “fitted”.
5. Hide object PolynomialExpansion.
6. Hide object VectorAssembler.
7. Word2Vec.minCount (and other params) -> group param.
8. ParamValidators -> DeveloperApi.
9. Hide MetadataUtils/SchemaUtils.
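
For context, a minimal usage sketch (not part of the patch; names assumed) of items 2 and 3: the relocated evaluator lives in ml.evaluation and returns a metric where larger is better, so model selection can always maximize it.

import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator

val evaluator = new BinaryClassificationEvaluator()
  .setRawPredictionCol("rawPrediction")
  .setMetricName("areaUnderROC")
// evaluator.evaluate(predictions) returns AUC for an assumed DataFrame
// `predictions` holding label and rawPrediction columns; larger is better.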

jkbradley

Author: Xiangrui Meng <meng@databricks.com>

Closes #6322 from mengxr/SPARK-7535.0 and squashes the following commits:

9e9c7da [Xiangrui Meng] move JavaEvaluator to ml.evaluation as well
e179480 [Xiangrui Meng] move Evaluation to ml.evaluation in PySpark
08ef61f [Xiangrui Meng] update pipeline APIs
parent e4136ea6
Showing with 84 additions and 80 deletions
@@ -28,7 +28,7 @@ import org.apache.spark.sql.DataFrame
  * Abstract class for estimators that fit models to data.
  */
 @AlphaComponent
-abstract class Estimator[M <: Model[M]] extends PipelineStage with Params {
+abstract class Estimator[M <: Model[M]] extends PipelineStage {

   /**
    * Fits a single model to the input data with optional parameters.
...
@@ -170,7 +170,7 @@ class Pipeline(override val uid: String) extends Estimator[PipelineModel] {

 /**
  * :: AlphaComponent ::
- * Represents a compiled pipeline.
+ * Represents a fitted pipeline.
  */
 @AlphaComponent
 class PipelineModel private[ml] (
...
@@ -32,7 +32,7 @@ import org.apache.spark.sql.types._
  * Abstract class for transformers that transform one dataset into another.
  */
 @AlphaComponent
-abstract class Transformer extends PipelineStage with Params {
+abstract class Transformer extends PipelineStage {

   /**
    * Transforms the dataset with optional parameters
...
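
A hedged sketch (class name illustrative; assumes the 1.4 abstract members transform, transformSchema, and copy) of what the slimmer hierarchy means for a custom transformer: extending Transformer alone now suffices, because PipelineStage already mixes in Params.

import org.apache.spark.ml.Transformer
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.types.StructType

// No "with Params" needed: PipelineStage supplies it.
class IdentityTransformer(override val uid: String) extends Transformer {
  def this() = this(Identifiable.randomUID("identity"))
  // Pass data and schema through untouched.
  override def transform(dataset: DataFrame): DataFrame = dataset
  override def transformSchema(schema: StructType): StructType = schema
  // No params to carry over in this sketch.
  override def copy(extra: ParamMap): IdentityTransformer = new IdentityTransformer(uid)
}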
@@ -18,7 +18,7 @@
 package org.apache.spark.ml.evaluation

 import org.apache.spark.annotation.AlphaComponent
-import org.apache.spark.ml.Evaluator
+import org.apache.spark.ml.evaluation.Evaluator
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared._
 import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
...
@@ -15,7 +15,7 @@
  * limitations under the License.
  */

-package org.apache.spark.ml
+package org.apache.spark.ml.evaluation

 import org.apache.spark.annotation.AlphaComponent
 import org.apache.spark.ml.param.{ParamMap, Params}
@@ -29,7 +29,7 @@ import org.apache.spark.sql.DataFrame
 abstract class Evaluator extends Params {

   /**
-   * Evaluates the output.
+   * Evaluates model output and returns a scalar metric (larger is better).
    *
    * @param dataset a dataset that contains labels/observations and predictions.
    * @param paramMap parameter map that specifies the input columns and output metrics
...
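
A hedged sketch (names illustrative; assumes the 1.4 abstract members evaluate(dataset) and copy(extra)) of a custom metric under the clarified contract, where larger values must mean better models:

import org.apache.spark.ml.evaluation.Evaluator
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.DataFrame

class AccuracyEvaluator(override val uid: String) extends Evaluator {
  def this() = this(Identifiable.randomUID("accuracy"))
  // Fraction of rows where prediction matches label: higher means better.
  override def evaluate(dataset: DataFrame): Double = {
    val total = dataset.count().toDouble
    val hits = dataset.where(dataset("prediction") === dataset("label")).count()
    hits / total
  }
  // No params to carry over in this sketch.
  override def copy(extra: ParamMap): AccuracyEvaluator = new AccuracyEvaluator(uid)
}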
@@ -75,7 +75,7 @@ class PolynomialExpansion(override val uid: String)
  * To handle sparsity, if c is zero, we can skip all monomials that contain it. We remember the
  * current index and increment it properly for sparse input.
  */
-object PolynomialExpansion {
+private[feature] object PolynomialExpansion {

   private def choose(n: Int, k: Int): Int = {
     Range(n, n - k, -1).product / Range(k, 1, -1).product
...
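
A quick check (not part of the patch) of the combinatorics behind choose: it computes n!/(k!(n-k)!), and expanding n features to degree d yields choose(n + d, d) - 1 monomials once the constant term is dropped.

def choose(n: Int, k: Int): Int =
  Range(n, n - k, -1).product / Range(k, 1, -1).product

val n = 3; val d = 2
// 9 monomials for (x, y, z) at degree 2: x, y, z, x^2, xy, y^2, xz, yz, z^2
val expandedDim = choose(n + d, d) - 1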
@@ -78,8 +78,7 @@ class VectorAssembler(override val uid: String)
   }
 }

-@AlphaComponent
-object VectorAssembler {
+private object VectorAssembler {

   private[feature] def assemble(vv: Any*): Vector = {
     val indices = ArrayBuilder.make[Int]
...
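
With the companion object now private, the VectorAssembler class itself stays the public entry point. A hedged usage sketch (column names illustrative):

import org.apache.spark.ml.feature.VectorAssembler

val assembler = new VectorAssembler()
  .setInputCols(Array("age", "income", "clicks"))
  .setOutputCol("features")
// assembler.transform(df) appends a "features" vector column to an assumed
// DataFrame df containing the three numeric input columns.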
@@ -37,6 +37,7 @@ private[feature] trait Word2VecBase extends Params

   /**
    * The dimension of the code that you want to transform from words.
+   * @group param
    */
   final val vectorSize = new IntParam(
     this, "vectorSize", "the dimension of codes after transforming from words")
@@ -47,6 +48,7 @@ private[feature] trait Word2VecBase extends Params

   /**
    * Number of partitions for sentences of words.
+   * @group param
    */
   final val numPartitions = new IntParam(
     this, "numPartitions", "number of partitions for sentences of words")
@@ -58,6 +60,7 @@ private[feature] trait Word2VecBase extends Params
   /**
    * The minimum number of times a token must appear to be included in the word2vec model's
    * vocabulary.
+   * @group param
    */
   final val minCount = new IntParam(this, "minCount", "the minimum number of times a token must " +
     "appear to be included in the word2vec model's vocabulary")
...
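
A hedged sketch exercising the newly grouped params (values and column names illustrative):

import org.apache.spark.ml.feature.Word2Vec

val word2Vec = new Word2Vec()
  .setInputCol("tokens")
  .setOutputCol("vectors")
  .setVectorSize(100)   // dimension of the learned code
  .setNumPartitions(4)  // partitions for sentences of words
  .setMinCount(5)       // drop tokens seen fewer than five times
// word2Vec.fit(df) trains on an assumed DataFrame df with a "tokens"
// column of word sequences.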
@@ -24,7 +24,7 @@ import scala.annotation.varargs
 import scala.collection.mutable
 import scala.collection.JavaConverters._

-import org.apache.spark.annotation.AlphaComponent
+import org.apache.spark.annotation.{DeveloperApi, AlphaComponent}
 import org.apache.spark.ml.util.Identifiable

 /**
@@ -92,9 +92,11 @@ class Param[T](val parent: String, val name: String, val doc: String, val isVali
 }

 /**
+ * :: DeveloperApi ::
  * Factory methods for common validation functions for [[Param.isValid]].
  * The numerical methods only support Int, Long, Float, and Double.
  */
+@DeveloperApi
 object ParamValidators {

   /** (private[param]) Default validation always return true */
@@ -529,11 +531,13 @@ trait Params extends Identifiable with Serializable {
 }

 /**
+ * :: DeveloperApi ::
  * Java-friendly wrapper for [[Params]].
  * Java developers who need to extend [[Params]] should use this class instead.
  * If you need to extend a abstract class which already extends [[Params]], then that abstract
  * class should be Java-friendly as well.
  */
+@DeveloperApi
 abstract class JavaParams extends Params

 /**
...
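
ParamValidators is now tagged @DeveloperApi; a hedged sketch of how its factories plug into a param's isValid argument (the threshold declaration below is illustrative):

import org.apache.spark.ml.param.ParamValidators

val inUnitInterval = ParamValidators.inRange[Double](0.0, 1.0)
assert(inUnitInterval(0.5) && !inUnitInterval(1.5))
// Typical use inside a Params subclass:
//   final val threshold = new DoubleParam(this, "threshold",
//     "decision threshold in [0, 1]", ParamValidators.inRange(0.0, 1.0))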
@@ -22,6 +22,7 @@ import com.github.fommil.netlib.F2jBLAS
 import org.apache.spark.Logging
 import org.apache.spark.annotation.AlphaComponent
 import org.apache.spark.ml._
+import org.apache.spark.ml.evaluation.Evaluator
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.util.Identifiable
 import org.apache.spark.mllib.util.MLUtils
...
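
A hedged wiring sketch (estimator and grid illustrative) of why CrossValidator needs the relocated Evaluator: it fits over a param grid and keeps the model whose metric is largest.

import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}

val lr = new LogisticRegression()
val grid = new ParamGridBuilder()
  .addGrid(lr.regParam, Array(0.01, 0.1))
  .build()
val cv = new CrossValidator()
  .setEstimator(lr)
  .setEvaluator(new BinaryClassificationEvaluator())  // larger metric wins
  .setEstimatorParamMaps(grid)
  .setNumFolds(3)
// cv.fit(training) selects over an assumed training DataFrame.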
@@ -19,18 +19,14 @@ package org.apache.spark.ml.util

 import scala.collection.immutable.HashMap

-import org.apache.spark.annotation.Experimental
 import org.apache.spark.ml.attribute._
 import org.apache.spark.sql.types.StructField

 /**
- * :: Experimental ::
- *
  * Helper utilities for tree-based algorithms
  */
-@Experimental
-object MetadataUtils {
+private[spark] object MetadataUtils {

   /**
    * Examine a schema to identify the number of classes in a label column.
...
@@ -17,15 +17,13 @@
 package org.apache.spark.ml.util

-import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.sql.types.{DataType, StructField, StructType}

 /**
- * :: DeveloperApi ::
  * Utils for handling schemas.
  */
-@DeveloperApi
-object SchemaUtils {
+private[spark] object SchemaUtils {

   // TODO: Move the utility methods to SQL.
...
@@ -15,6 +15,6 @@
 # limitations under the License.
 #

-from pyspark.ml.pipeline import Transformer, Estimator, Model, Pipeline, PipelineModel, Evaluator
+from pyspark.ml.pipeline import Transformer, Estimator, Model, Pipeline, PipelineModel

-__all__ = ["Transformer", "Estimator", "Model", "Pipeline", "PipelineModel", "Evaluator"]
+__all__ = ["Transformer", "Estimator", "Model", "Pipeline", "PipelineModel"]
@@ -15,13 +15,72 @@
 # limitations under the License.
 #

-from pyspark.ml.wrapper import JavaEvaluator
+from abc import abstractmethod, ABCMeta
+
+from pyspark.ml.wrapper import JavaWrapper
 from pyspark.ml.param import Param, Params
 from pyspark.ml.param.shared import HasLabelCol, HasRawPredictionCol
 from pyspark.ml.util import keyword_only
 from pyspark.mllib.common import inherit_doc

-__all__ = ['BinaryClassificationEvaluator']
+__all__ = ['Evaluator', 'BinaryClassificationEvaluator']
+
+
+@inherit_doc
+class Evaluator(Params):
+    """
+    Base class for evaluators that compute metrics from predictions.
+    """
+
+    __metaclass__ = ABCMeta
+
+    @abstractmethod
+    def _evaluate(self, dataset):
+        """
+        Evaluates the output.
+
+        :param dataset: a dataset that contains labels/observations and
+                        predictions
+        :return: metric
+        """
+        raise NotImplementedError()
+
+    def evaluate(self, dataset, params={}):
+        """
+        Evaluates the output with optional parameters.
+
+        :param dataset: a dataset that contains labels/observations and
+                        predictions
+        :param params: an optional param map that overrides embedded
+                       params
+        :return: metric
+        """
+        if isinstance(params, dict):
+            if params:
+                return self.copy(params)._evaluate(dataset)
+            else:
+                return self._evaluate(dataset)
+        else:
+            raise ValueError("Params must be a param map but got %s." % type(params))
+
+
+@inherit_doc
+class JavaEvaluator(Evaluator, JavaWrapper):
+    """
+    Base class for :py:class:`Evaluator`s that wrap Java/Scala
+    implementations.
+    """
+
+    __metaclass__ = ABCMeta
+
+    def _evaluate(self, dataset):
+        """
+        Evaluates the output.
+
+        :param dataset: a dataset that contains labels/observations and predictions.
+        :return: evaluation metric
+        """
+        self._transfer_params_to_java()
+        return self._java_obj.evaluate(dataset._jdf)
+

 @inherit_doc
...
@@ -219,40 +219,3 @@ class PipelineModel(Model):

     def copy(self, extra={}):
         stages = [stage.copy(extra) for stage in self.stages]
         return PipelineModel(stages)
-
-
-class Evaluator(Params):
-    """
-    Base class for evaluators that compute metrics from predictions.
-    """
-
-    __metaclass__ = ABCMeta
-
-    @abstractmethod
-    def _evaluate(self, dataset):
-        """
-        Evaluates the output.
-
-        :param dataset: a dataset that contains labels/observations and
-                        predictions
-        :return: metric
-        """
-        raise NotImplementedError()
-
-    def evaluate(self, dataset, params={}):
-        """
-        Evaluates the output with optional parameters.
-
-        :param dataset: a dataset that contains labels/observations and
-                        predictions
-        :param params: an optional param map that overrides embedded
-                       params
-        :return: metric
-        """
-        if isinstance(params, dict):
-            if params:
-                return self.copy(params)._evaluate(dataset)
-            else:
-                return self._evaluate(dataset)
-        else:
-            raise ValueError("Params must be a param map but got %s." % type(params))
@@ -20,7 +20,7 @@ from abc import ABCMeta

 from pyspark import SparkContext
 from pyspark.sql import DataFrame
 from pyspark.ml.param import Params
-from pyspark.ml.pipeline import Estimator, Transformer, Evaluator, Model
+from pyspark.ml.pipeline import Estimator, Transformer, Model
 from pyspark.mllib.common import inherit_doc, _java2py, _py2java
@@ -185,22 +185,3 @@ class JavaModel(Model, JavaTransformer):
         sc = SparkContext._active_spark_context
         java_args = [_py2java(sc, arg) for arg in args]
         return _java2py(sc, m(*java_args))
-
-
-@inherit_doc
-class JavaEvaluator(Evaluator, JavaWrapper):
-    """
-    Base class for :py:class:`Evaluator`s that wrap Java/Scala
-    implementations.
-    """
-
-    __metaclass__ = ABCMeta
-
-    def _evaluate(self, dataset):
-        """
-        Evaluates the output.
-
-        :param dataset: a dataset that contains labels/observations and predictions.
-        :return: evaluation metric
-        """
-        self._transfer_params_to_java()
-        return self._java_obj.evaluate(dataset._jdf)