Commit 7d05a624, authored by Takahashi Hiroshi, committed by Xiangrui Meng
[SPARK-10259][ML] Add @since annotation to ml.classification

Add @Since annotation to ml.classification

Author: Takahashi Hiroshi <takahashi.hiroshi@lab.ntt.co.jp>

Closes #8534 from taishi-oss/issue10259.
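
For readers skimming the diff below: the commit applies one consistent pattern. Each public class is tagged with the release it first shipped in, its primary constructor and each public constructor parameter are tagged individually (Scala allows an annotation between the class name and the parameter list, which is what the `@Since("1.4.0") (` syntax below does), and every public member carries its own version, which may be newer than the class itself. Here is a minimal, self-contained sketch of the pattern (the `Since` stand-in and `ExampleClassifier` are hypothetical and not part of the commit; in Spark the real annotation is org.apache.spark.annotation.Since):

    import scala.annotation.StaticAnnotation

    // Stand-in so the sketch compiles without Spark on the classpath.
    class Since(version: String) extends StaticAnnotation

    @Since("1.4.0")                                   // the class itself
    final class ExampleClassifier @Since("1.4.0") (   // the primary constructor
        @Since("1.4.0") val uid: String) {            // each public constructor parameter

      @Since("1.4.0")
      def this() = this("example_" + java.util.UUID.randomUUID())

      // Members added in a later release carry that release's version.
      @Since("1.6.0")
      def setSeed(value: Long): this.type = this
    }

Note how the model classes below (for example LogisticRegressionModel) tag individual constructor parameters with different versions: the parameter list grew across releases, so each parameter records the release in which it appeared.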
parent 73896588
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
@@ -17,7 +17,7 @@
 package org.apache.spark.ml.classification
 
-import org.apache.spark.annotation.Experimental
+import org.apache.spark.annotation.{Experimental, Since}
 import org.apache.spark.ml.param.ParamMap
 import org.apache.spark.ml.tree.{DecisionTreeModel, DecisionTreeParams, Node, TreeClassifierParams}
 import org.apache.spark.ml.tree.impl.RandomForest
@@ -36,32 +36,44 @@ import org.apache.spark.sql.DataFrame
  * It supports both binary and multiclass labels, as well as both continuous and categorical
  * features.
  */
+@Since("1.4.0")
 @Experimental
-final class DecisionTreeClassifier(override val uid: String)
+final class DecisionTreeClassifier @Since("1.4.0") (
+    @Since("1.4.0") override val uid: String)
   extends ProbabilisticClassifier[Vector, DecisionTreeClassifier, DecisionTreeClassificationModel]
   with DecisionTreeParams with TreeClassifierParams {
 
+  @Since("1.4.0")
   def this() = this(Identifiable.randomUID("dtc"))
 
   // Override parameter setters from parent trait for Java API compatibility.
+  @Since("1.4.0")
   override def setMaxDepth(value: Int): this.type = super.setMaxDepth(value)
 
+  @Since("1.4.0")
   override def setMaxBins(value: Int): this.type = super.setMaxBins(value)
 
+  @Since("1.4.0")
   override def setMinInstancesPerNode(value: Int): this.type =
     super.setMinInstancesPerNode(value)
 
+  @Since("1.4.0")
   override def setMinInfoGain(value: Double): this.type = super.setMinInfoGain(value)
 
+  @Since("1.4.0")
   override def setMaxMemoryInMB(value: Int): this.type = super.setMaxMemoryInMB(value)
 
+  @Since("1.4.0")
   override def setCacheNodeIds(value: Boolean): this.type = super.setCacheNodeIds(value)
 
+  @Since("1.4.0")
   override def setCheckpointInterval(value: Int): this.type = super.setCheckpointInterval(value)
 
+  @Since("1.4.0")
   override def setImpurity(value: String): this.type = super.setImpurity(value)
 
+  @Since("1.6.0")
   override def setSeed(value: Long): this.type = super.setSeed(value)
 
   override protected def train(dataset: DataFrame): DecisionTreeClassificationModel = {
@@ -89,12 +101,15 @@ final class DecisionTreeClassifier(override val uid: String)
       subsamplingRate = 1.0)
   }
 
+  @Since("1.4.1")
   override def copy(extra: ParamMap): DecisionTreeClassifier = defaultCopy(extra)
 }
 
+@Since("1.4.0")
 @Experimental
 object DecisionTreeClassifier {
   /** Accessor for supported impurities: entropy, gini */
+  @Since("1.4.0")
   final val supportedImpurities: Array[String] = TreeClassifierParams.supportedImpurities
 }
 
@@ -104,12 +119,13 @@ object DecisionTreeClassifier {
  * It supports both binary and multiclass labels, as well as both continuous and categorical
  * features.
  */
+@Since("1.4.0")
 @Experimental
 final class DecisionTreeClassificationModel private[ml] (
-    override val uid: String,
-    override val rootNode: Node,
-    override val numFeatures: Int,
-    override val numClasses: Int)
+    @Since("1.4.0")override val uid: String,
+    @Since("1.4.0")override val rootNode: Node,
+    @Since("1.6.0")override val numFeatures: Int,
+    @Since("1.5.0")override val numClasses: Int)
   extends ProbabilisticClassificationModel[Vector, DecisionTreeClassificationModel]
   with DecisionTreeModel with Serializable {
@@ -142,11 +158,13 @@ final class DecisionTreeClassificationModel private[ml] (
     }
   }
 
+  @Since("1.4.0")
   override def copy(extra: ParamMap): DecisionTreeClassificationModel = {
     copyValues(new DecisionTreeClassificationModel(uid, rootNode, numFeatures, numClasses), extra)
       .setParent(parent)
   }
 
+  @Since("1.4.0")
   override def toString: String = {
     s"DecisionTreeClassificationModel (uid=$uid) of depth $depth with $numNodes nodes"
   }
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
@@ -20,7 +20,7 @@
 import com.github.fommil.netlib.BLAS.{getInstance => blas}
 
 import org.apache.spark.Logging
-import org.apache.spark.annotation.Experimental
+import org.apache.spark.annotation.{Experimental, Since}
 import org.apache.spark.ml.{PredictionModel, Predictor}
 import org.apache.spark.ml.param.{Param, ParamMap}
 import org.apache.spark.ml.regression.DecisionTreeRegressionModel
@@ -44,36 +44,47 @@ import org.apache.spark.sql.types.DoubleType
  * It supports binary labels, as well as both continuous and categorical features.
  * Note: Multiclass labels are not currently supported.
  */
+@Since("1.4.0")
 @Experimental
-final class GBTClassifier(override val uid: String)
+final class GBTClassifier @Since("1.4.0") (
+    @Since("1.4.0") override val uid: String)
   extends Predictor[Vector, GBTClassifier, GBTClassificationModel]
   with GBTParams with TreeClassifierParams with Logging {
 
+  @Since("1.4.0")
   def this() = this(Identifiable.randomUID("gbtc"))
 
   // Override parameter setters from parent trait for Java API compatibility.
 
   // Parameters from TreeClassifierParams:
+  @Since("1.4.0")
   override def setMaxDepth(value: Int): this.type = super.setMaxDepth(value)
 
+  @Since("1.4.0")
   override def setMaxBins(value: Int): this.type = super.setMaxBins(value)
 
+  @Since("1.4.0")
   override def setMinInstancesPerNode(value: Int): this.type =
     super.setMinInstancesPerNode(value)
 
+  @Since("1.4.0")
   override def setMinInfoGain(value: Double): this.type = super.setMinInfoGain(value)
 
+  @Since("1.4.0")
   override def setMaxMemoryInMB(value: Int): this.type = super.setMaxMemoryInMB(value)
 
+  @Since("1.4.0")
   override def setCacheNodeIds(value: Boolean): this.type = super.setCacheNodeIds(value)
 
+  @Since("1.4.0")
   override def setCheckpointInterval(value: Int): this.type = super.setCheckpointInterval(value)
 
   /**
    * The impurity setting is ignored for GBT models.
    * Individual trees are built using impurity "Variance."
    */
+  @Since("1.4.0")
   override def setImpurity(value: String): this.type = {
     logWarning("GBTClassifier.setImpurity should NOT be used")
     this
@@ -81,8 +92,10 @@ final class GBTClassifier(override val uid: String)
 
   // Parameters from TreeEnsembleParams:
 
+  @Since("1.4.0")
   override def setSubsamplingRate(value: Double): this.type = super.setSubsamplingRate(value)
 
+  @Since("1.4.0")
   override def setSeed(value: Long): this.type = {
     logWarning("The 'seed' parameter is currently ignored by Gradient Boosting.")
     super.setSeed(value)
@@ -90,8 +103,10 @@ final class GBTClassifier(override val uid: String)
 
   // Parameters from GBTParams:
 
+  @Since("1.4.0")
   override def setMaxIter(value: Int): this.type = super.setMaxIter(value)
 
+  @Since("1.4.0")
   override def setStepSize(value: Double): this.type = super.setStepSize(value)
 
   // Parameters for GBTClassifier:
@@ -102,6 +117,7 @@ final class GBTClassifier(override val uid: String)
    * (default = logistic)
    * @group param
    */
+  @Since("1.4.0")
   val lossType: Param[String] = new Param[String](this, "lossType", "Loss function which GBT" +
     " tries to minimize (case-insensitive). Supported options:" +
     s" ${GBTClassifier.supportedLossTypes.mkString(", ")}",
@@ -110,9 +126,11 @@ final class GBTClassifier(override val uid: String)
   setDefault(lossType -> "logistic")
 
   /** @group setParam */
+  @Since("1.4.0")
   def setLossType(value: String): this.type = set(lossType, value)
 
   /** @group getParam */
+  @Since("1.4.0")
   def getLossType: String = $(lossType).toLowerCase
 
   /** (private[ml]) Convert new loss to old loss. */
@@ -145,13 +163,16 @@ final class GBTClassifier(override val uid: String)
     GBTClassificationModel.fromOld(oldModel, this, categoricalFeatures, numFeatures)
   }
 
+  @Since("1.4.1")
   override def copy(extra: ParamMap): GBTClassifier = defaultCopy(extra)
 }
 
+@Since("1.4.0")
 @Experimental
 object GBTClassifier {
   // The losses below should be lowercase.
   /** Accessor for supported loss settings: logistic */
+  @Since("1.4.0")
   final val supportedLossTypes: Array[String] = Array("logistic").map(_.toLowerCase)
 }
 
@@ -164,12 +185,13 @@ object GBTClassifier {
  * @param _trees Decision trees in the ensemble.
 * @param _treeWeights Weights for the decision trees in the ensemble.
  */
+@Since("1.6.0")
 @Experimental
 final class GBTClassificationModel private[ml](
-    override val uid: String,
+    @Since("1.6.0") override val uid: String,
     private val _trees: Array[DecisionTreeRegressionModel],
     private val _treeWeights: Array[Double],
-    override val numFeatures: Int)
+    @Since("1.6.0") override val numFeatures: Int)
   extends PredictionModel[Vector, GBTClassificationModel]
   with TreeEnsembleModel with Serializable {
@@ -182,11 +204,14 @@ final class GBTClassificationModel private[ml](
    * @param _trees Decision trees in the ensemble.
    * @param _treeWeights Weights for the decision trees in the ensemble.
    */
+  @Since("1.6.0")
   def this(uid: String, _trees: Array[DecisionTreeRegressionModel], _treeWeights: Array[Double]) =
     this(uid, _trees, _treeWeights, -1)
 
+  @Since("1.4.0")
   override def trees: Array[DecisionTreeModel] = _trees.asInstanceOf[Array[DecisionTreeModel]]
 
+  @Since("1.4.0")
   override def treeWeights: Array[Double] = _treeWeights
 
   override protected def transformImpl(dataset: DataFrame): DataFrame = {
@@ -205,11 +230,13 @@ final class GBTClassificationModel private[ml](
     if (prediction > 0.0) 1.0 else 0.0
   }
 
+  @Since("1.4.0")
   override def copy(extra: ParamMap): GBTClassificationModel = {
     copyValues(new GBTClassificationModel(uid, _trees, _treeWeights, numFeatures),
       extra).setParent(parent)
   }
 
+  @Since("1.4.0")
   override def toString: String = {
     s"GBTClassificationModel (uid=$uid) with $numTrees trees"
   }
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -24,7 +24,7 @@ import breeze.optimize.{CachedDiffFunction, DiffFunction, LBFGS => BreezeLBFGS,
 import org.apache.hadoop.fs.Path
 
 import org.apache.spark.{Logging, SparkException}
-import org.apache.spark.annotation.{Since, Experimental}
+import org.apache.spark.annotation.{Experimental, Since}
 import org.apache.spark.ml.feature.Instance
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared._
@@ -154,11 +154,14 @@ private[classification] trait LogisticRegressionParams extends ProbabilisticClas
  * Currently, this class only supports binary classification. It will support multiclass
  * in the future.
  */
+@Since("1.2.0")
 @Experimental
-class LogisticRegression(override val uid: String)
+class LogisticRegression @Since("1.2.0") (
+    @Since("1.4.0") override val uid: String)
   extends ProbabilisticClassifier[Vector, LogisticRegression, LogisticRegressionModel]
   with LogisticRegressionParams with DefaultParamsWritable with Logging {
 
+  @Since("1.4.0")
   def this() = this(Identifiable.randomUID("logreg"))
 
   /**
@@ -166,6 +169,7 @@ class LogisticRegression(override val uid: String)
    * Default is 0.0.
    * @group setParam
    */
+  @Since("1.2.0")
   def setRegParam(value: Double): this.type = set(regParam, value)
   setDefault(regParam -> 0.0)
@@ -176,6 +180,7 @@ class LogisticRegression(override val uid: String)
    * Default is 0.0 which is an L2 penalty.
    * @group setParam
    */
+  @Since("1.4.0")
   def setElasticNetParam(value: Double): this.type = set(elasticNetParam, value)
   setDefault(elasticNetParam -> 0.0)
@@ -184,6 +189,7 @@ class LogisticRegression(override val uid: String)
    * Default is 100.
    * @group setParam
    */
+  @Since("1.2.0")
   def setMaxIter(value: Int): this.type = set(maxIter, value)
   setDefault(maxIter -> 100)
@@ -193,6 +199,7 @@ class LogisticRegression(override val uid: String)
    * Default is 1E-6.
    * @group setParam
   */
+  @Since("1.4.0")
   def setTol(value: Double): this.type = set(tol, value)
   setDefault(tol -> 1E-6)
@@ -201,6 +208,7 @@ class LogisticRegression(override val uid: String)
    * Default is true.
    * @group setParam
    */
+  @Since("1.4.0")
   def setFitIntercept(value: Boolean): this.type = set(fitIntercept, value)
   setDefault(fitIntercept -> true)
@@ -213,11 +221,14 @@ class LogisticRegression(override val uid: String)
    * Default is true.
    * @group setParam
    */
+  @Since("1.5.0")
   def setStandardization(value: Boolean): this.type = set(standardization, value)
   setDefault(standardization -> true)
 
+  @Since("1.5.0")
   override def setThreshold(value: Double): this.type = super.setThreshold(value)
 
+  @Since("1.5.0")
   override def getThreshold: Double = super.getThreshold
 
   /**
@@ -226,11 +237,14 @@ class LogisticRegression(override val uid: String)
    * Default is empty, so all instances have weight one.
    * @group setParam
    */
+  @Since("1.6.0")
   def setWeightCol(value: String): this.type = set(weightCol, value)
   setDefault(weightCol -> "")
 
+  @Since("1.5.0")
   override def setThresholds(value: Array[Double]): this.type = super.setThresholds(value)
 
+  @Since("1.5.0")
   override def getThresholds: Array[Double] = super.getThresholds
 
   override protected def train(dataset: DataFrame): LogisticRegressionModel = {
@@ -384,11 +398,14 @@ class LogisticRegression(override val uid: String)
     model.setSummary(logRegSummary)
   }
 
+  @Since("1.4.0")
   override def copy(extra: ParamMap): LogisticRegression = defaultCopy(extra)
 }
 
+@Since("1.6.0")
 object LogisticRegression extends DefaultParamsReadable[LogisticRegression] {
 
+  @Since("1.6.0")
   override def load(path: String): LogisticRegression = super.load(path)
 }
 
@@ -396,23 +413,28 @@ object LogisticRegression extends DefaultParamsReadable[LogisticRegression] {
  * :: Experimental ::
  * Model produced by [[LogisticRegression]].
  */
+@Since("1.4.0")
 @Experimental
 class LogisticRegressionModel private[ml] (
-    override val uid: String,
-    val coefficients: Vector,
-    val intercept: Double)
+    @Since("1.4.0") override val uid: String,
+    @Since("1.6.0") val coefficients: Vector,
+    @Since("1.3.0") val intercept: Double)
   extends ProbabilisticClassificationModel[Vector, LogisticRegressionModel]
   with LogisticRegressionParams with MLWritable {
 
   @deprecated("Use coefficients instead.", "1.6.0")
   def weights: Vector = coefficients
 
+  @Since("1.5.0")
   override def setThreshold(value: Double): this.type = super.setThreshold(value)
 
+  @Since("1.5.0")
   override def getThreshold: Double = super.getThreshold
 
+  @Since("1.5.0")
   override def setThresholds(value: Array[Double]): this.type = super.setThresholds(value)
 
+  @Since("1.5.0")
   override def getThresholds: Array[Double] = super.getThresholds
 
   /** Margin (rawPrediction) for class label 1. For binary classification only. */
@@ -426,8 +448,10 @@ class LogisticRegressionModel private[ml] (
     1.0 / (1.0 + math.exp(-m))
   }
 
+  @Since("1.6.0")
   override val numFeatures: Int = coefficients.size
 
+  @Since("1.3.0")
   override val numClasses: Int = 2
 
   private var trainingSummary: Option[LogisticRegressionTrainingSummary] = None
@@ -436,6 +460,7 @@ class LogisticRegressionModel private[ml] (
    * Gets summary of model on training set. An exception is
    * thrown if `trainingSummary == None`.
    */
+  @Since("1.5.0")
   def summary: LogisticRegressionTrainingSummary = trainingSummary match {
     case Some(summ) => summ
     case None =>
@@ -451,6 +476,7 @@ class LogisticRegressionModel private[ml] (
   }
 
   /** Indicates whether a training summary exists for this model instance. */
+  @Since("1.5.0")
   def hasSummary: Boolean = trainingSummary.isDefined
 
   /**
@@ -493,6 +519,7 @@ class LogisticRegressionModel private[ml] (
     Vectors.dense(-m, m)
   }
 
+  @Since("1.4.0")
   override def copy(extra: ParamMap): LogisticRegressionModel = {
     val newModel = copyValues(new LogisticRegressionModel(uid, coefficients, intercept), extra)
     if (trainingSummary.isDefined) newModel.setSummary(trainingSummary.get)
@@ -710,12 +737,13 @@ sealed trait LogisticRegressionSummary extends Serializable {
  * @param objectiveHistory objective function (scaled loss + regularization) at each iteration.
  */
 @Experimental
+@Since("1.5.0")
 class BinaryLogisticRegressionTrainingSummary private[classification] (
-    predictions: DataFrame,
-    probabilityCol: String,
-    labelCol: String,
-    featuresCol: String,
-    val objectiveHistory: Array[Double])
+    @Since("1.5.0") predictions: DataFrame,
+    @Since("1.5.0") probabilityCol: String,
+    @Since("1.5.0") labelCol: String,
+    @Since("1.6.0") featuresCol: String,
+    @Since("1.5.0") val objectiveHistory: Array[Double])
   extends BinaryLogisticRegressionSummary(predictions, probabilityCol, labelCol, featuresCol)
   with LogisticRegressionTrainingSummary {
@@ -731,11 +759,13 @@ class BinaryLogisticRegressionTrainingSummary private[classification] (
  * @param featuresCol field in "predictions" which gives the features of each instance as a vector.
  */
 @Experimental
+@Since("1.5.0")
 class BinaryLogisticRegressionSummary private[classification] (
-    @transient override val predictions: DataFrame,
-    override val probabilityCol: String,
-    override val labelCol: String,
-    override val featuresCol: String) extends LogisticRegressionSummary {
+    @Since("1.5.0") @transient override val predictions: DataFrame,
+    @Since("1.5.0") override val probabilityCol: String,
+    @Since("1.5.0") override val labelCol: String,
+    @Since("1.6.0") override val featuresCol: String) extends LogisticRegressionSummary {
 
   private val sqlContext = predictions.sqlContext
   import sqlContext.implicits._
@@ -760,6 +790,7 @@ class BinaryLogisticRegressionSummary private[classification] (
    * This will change in later Spark versions.
    * @see http://en.wikipedia.org/wiki/Receiver_operating_characteristic
    */
+  @Since("1.5.0")
   @transient lazy val roc: DataFrame = binaryMetrics.roc().toDF("FPR", "TPR")
 
   /**
@@ -768,6 +799,7 @@ class BinaryLogisticRegressionSummary private[classification] (
    * Note: This ignores instance weights (setting all to 1.0) from [[LogisticRegression.weightCol]].
    * This will change in later Spark versions.
    */
+  @Since("1.5.0")
   lazy val areaUnderROC: Double = binaryMetrics.areaUnderROC()
 
   /**
@@ -777,6 +809,7 @@ class BinaryLogisticRegressionSummary private[classification] (
    * Note: This ignores instance weights (setting all to 1.0) from [[LogisticRegression.weightCol]].
    * This will change in later Spark versions.
   */
+  @Since("1.5.0")
   @transient lazy val pr: DataFrame = binaryMetrics.pr().toDF("recall", "precision")
 
   /**
@@ -785,6 +818,7 @@ class BinaryLogisticRegressionSummary private[classification] (
    * Note: This ignores instance weights (setting all to 1.0) from [[LogisticRegression.weightCol]].
    * This will change in later Spark versions.
    */
+  @Since("1.5.0")
   @transient lazy val fMeasureByThreshold: DataFrame = {
     binaryMetrics.fMeasureByThreshold().toDF("threshold", "F-Measure")
   }
@@ -797,6 +831,7 @@ class BinaryLogisticRegressionSummary private[classification] (
    * Note: This ignores instance weights (setting all to 1.0) from [[LogisticRegression.weightCol]].
    * This will change in later Spark versions.
    */
+  @Since("1.5.0")
   @transient lazy val precisionByThreshold: DataFrame = {
     binaryMetrics.precisionByThreshold().toDF("threshold", "precision")
   }
@@ -809,6 +844,7 @@ class BinaryLogisticRegressionSummary private[classification] (
    * Note: This ignores instance weights (setting all to 1.0) from [[LogisticRegression.weightCol]].
    * This will change in later Spark versions.
    */
+  @Since("1.5.0")
   @transient lazy val recallByThreshold: DataFrame = {
     binaryMetrics.recallByThreshold().toDF("threshold", "recall")
   }
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala
@@ -19,7 +19,7 @@ package org.apache.spark.ml.classification
 
 import scala.collection.JavaConverters._
 
-import org.apache.spark.annotation.Experimental
+import org.apache.spark.annotation.{Experimental, Since}
 import org.apache.spark.ml.param.shared.{HasTol, HasMaxIter, HasSeed}
 import org.apache.spark.ml.{PredictorParams, PredictionModel, Predictor}
 import org.apache.spark.ml.param.{IntParam, ParamValidators, IntArrayParam, ParamMap}
@@ -104,19 +104,23 @@ private object LabelConverter {
  * Each layer has sigmoid activation function, output layer has softmax.
  * Number of inputs has to be equal to the size of feature vectors.
  * Number of outputs has to be equal to the total number of labels.
- *
  */
+@Since("1.5.0")
 @Experimental
-class MultilayerPerceptronClassifier(override val uid: String)
+class MultilayerPerceptronClassifier @Since("1.5.0") (
+    @Since("1.5.0") override val uid: String)
   extends Predictor[Vector, MultilayerPerceptronClassifier, MultilayerPerceptronClassificationModel]
   with MultilayerPerceptronParams {
 
+  @Since("1.5.0")
   def this() = this(Identifiable.randomUID("mlpc"))
 
   /** @group setParam */
+  @Since("1.5.0")
   def setLayers(value: Array[Int]): this.type = set(layers, value)
 
   /** @group setParam */
+  @Since("1.5.0")
   def setBlockSize(value: Int): this.type = set(blockSize, value)
 
   /**
@@ -124,6 +128,7 @@ class MultilayerPerceptronClassifier(override val uid: String)
    * Default is 100.
    * @group setParam
    */
+  @Since("1.5.0")
   def setMaxIter(value: Int): this.type = set(maxIter, value)
 
   /**
@@ -132,14 +137,17 @@ class MultilayerPerceptronClassifier(override val uid: String)
    * Default is 1E-4.
    * @group setParam
    */
+  @Since("1.5.0")
   def setTol(value: Double): this.type = set(tol, value)
 
   /**
    * Set the seed for weights initialization.
    * @group setParam
    */
+  @Since("1.5.0")
   def setSeed(value: Long): this.type = set(seed, value)
 
+  @Since("1.5.0")
   override def copy(extra: ParamMap): MultilayerPerceptronClassifier = defaultCopy(extra)
 
   /**
@@ -173,14 +181,16 @@ class MultilayerPerceptronClassifier(override val uid: String)
  * @param weights vector of initial weights for the model that consists of the weights of layers
  * @return prediction model
  */
+@Since("1.5.0")
 @Experimental
 class MultilayerPerceptronClassificationModel private[ml] (
-    override val uid: String,
-    val layers: Array[Int],
-    val weights: Vector)
+    @Since("1.5.0") override val uid: String,
+    @Since("1.5.0") val layers: Array[Int],
+    @Since("1.5.0") val weights: Vector)
   extends PredictionModel[Vector, MultilayerPerceptronClassificationModel]
   with Serializable {
 
+  @Since("1.6.0")
   override val numFeatures: Int = layers.head
 
   private val mlpModel = FeedForwardTopology.multiLayerPerceptron(layers, true).getInstance(weights)
@@ -200,6 +210,7 @@ class MultilayerPerceptronClassificationModel private[ml] (
     LabelConverter.decodeLabel(mlpModel.predict(features))
   }
 
+  @Since("1.5.0")
   override def copy(extra: ParamMap): MultilayerPerceptronClassificationModel = {
     copyValues(new MultilayerPerceptronClassificationModel(uid, layers, weights), extra)
   }
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
@@ -72,11 +72,14 @@ private[ml] trait NaiveBayesParams extends PredictorParams {
  * ([[http://nlp.stanford.edu/IR-book/html/htmledition/the-bernoulli-model-1.html]]).
  * The input feature values must be nonnegative.
 */
+@Since("1.5.0")
 @Experimental
-class NaiveBayes(override val uid: String)
+class NaiveBayes @Since("1.5.0") (
+    @Since("1.5.0") override val uid: String)
   extends ProbabilisticClassifier[Vector, NaiveBayes, NaiveBayesModel]
   with NaiveBayesParams with DefaultParamsWritable {
 
+  @Since("1.5.0")
   def this() = this(Identifiable.randomUID("nb"))
 
   /**
@@ -84,6 +87,7 @@ class NaiveBayes(override val uid: String)
    * Default is 1.0.
    * @group setParam
    */
+  @Since("1.5.0")
   def setSmoothing(value: Double): this.type = set(smoothing, value)
   setDefault(smoothing -> 1.0)
@@ -93,6 +97,7 @@ class NaiveBayes(override val uid: String)
    * Default is "multinomial"
    * @group setParam
    */
+  @Since("1.5.0")
   def setModelType(value: String): this.type = set(modelType, value)
   setDefault(modelType -> OldNaiveBayes.Multinomial)
@@ -102,6 +107,7 @@ class NaiveBayes(override val uid: String)
     NaiveBayesModel.fromOld(oldModel, this)
   }
 
+  @Since("1.5.0")
   override def copy(extra: ParamMap): NaiveBayes = defaultCopy(extra)
 }
 
@@ -119,11 +125,12 @@ object NaiveBayes extends DefaultParamsReadable[NaiveBayes] {
  * @param theta log of class conditional probabilities, whose dimension is C (number of classes)
  *              by D (number of features)
  */
+@Since("1.5.0")
 @Experimental
 class NaiveBayesModel private[ml] (
-    override val uid: String,
-    val pi: Vector,
-    val theta: Matrix)
+    @Since("1.5.0") override val uid: String,
+    @Since("1.5.0") val pi: Vector,
+    @Since("1.5.0") val theta: Matrix)
   extends ProbabilisticClassificationModel[Vector, NaiveBayesModel]
   with NaiveBayesParams with MLWritable {
@@ -148,8 +155,10 @@ class NaiveBayesModel private[ml] (
       throw new UnknownError(s"Invalid modelType: ${$(modelType)}.")
   }
 
+  @Since("1.6.0")
   override val numFeatures: Int = theta.numCols
 
+  @Since("1.5.0")
   override val numClasses: Int = pi.size
 
   private def multinomialCalculation(features: Vector) = {
@@ -206,10 +215,12 @@ class NaiveBayesModel private[ml] (
     }
   }
 
+  @Since("1.5.0")
   override def copy(extra: ParamMap): NaiveBayesModel = {
     copyValues(new NaiveBayesModel(uid, pi, theta).setParent(this.parent), extra)
   }
 
+  @Since("1.5.0")
   override def toString: String = {
     s"NaiveBayesModel (uid=$uid) with ${pi.size} classes"
   }
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
@@ -21,7 +21,7 @@ import java.util.UUID
 
 import scala.language.existentials
 
-import org.apache.spark.annotation.Experimental
+import org.apache.spark.annotation.{Experimental, Since}
 import org.apache.spark.ml._
 import org.apache.spark.ml.attribute._
 import org.apache.spark.ml.param.{Param, ParamMap}
@@ -70,17 +70,20 @@ private[ml] trait OneVsRestParams extends PredictorParams {
  * The i-th model is produced by testing the i-th class (taking label 1) vs the rest
  * (taking label 0).
 */
+@Since("1.4.0")
 @Experimental
 final class OneVsRestModel private[ml] (
-    override val uid: String,
-    labelMetadata: Metadata,
-    val models: Array[_ <: ClassificationModel[_, _]])
+    @Since("1.4.0") override val uid: String,
+    @Since("1.4.0") labelMetadata: Metadata,
+    @Since("1.4.0") val models: Array[_ <: ClassificationModel[_, _]])
   extends Model[OneVsRestModel] with OneVsRestParams {
 
+  @Since("1.4.0")
   override def transformSchema(schema: StructType): StructType = {
     validateAndTransformSchema(schema, fitting = false, getClassifier.featuresDataType)
   }
 
+  @Since("1.4.0")
   override def transform(dataset: DataFrame): DataFrame = {
     // Check schema
     transformSchema(dataset.schema, logging = true)
@@ -134,6 +137,7 @@ final class OneVsRestModel private[ml] (
       .drop(accColName)
   }
 
+  @Since("1.4.1")
   override def copy(extra: ParamMap): OneVsRestModel = {
     val copied = new OneVsRestModel(
       uid, labelMetadata, models.map(_.copy(extra).asInstanceOf[ClassificationModel[_, _]]))
@@ -150,30 +154,39 @@ final class OneVsRestModel private[ml] (
  * Each example is scored against all k models and the model with highest score
  * is picked to label the example.
 */
+@Since("1.4.0")
 @Experimental
-final class OneVsRest(override val uid: String)
+final class OneVsRest @Since("1.4.0") (
+    @Since("1.4.0") override val uid: String)
   extends Estimator[OneVsRestModel] with OneVsRestParams {
 
+  @Since("1.4.0")
   def this() = this(Identifiable.randomUID("oneVsRest"))
 
   /** @group setParam */
+  @Since("1.4.0")
   def setClassifier(value: Classifier[_, _, _]): this.type = {
     set(classifier, value.asInstanceOf[ClassifierType])
   }
 
   /** @group setParam */
+  @Since("1.5.0")
   def setLabelCol(value: String): this.type = set(labelCol, value)
 
   /** @group setParam */
+  @Since("1.5.0")
   def setFeaturesCol(value: String): this.type = set(featuresCol, value)
 
   /** @group setParam */
+  @Since("1.5.0")
   def setPredictionCol(value: String): this.type = set(predictionCol, value)
 
+  @Since("1.4.0")
   override def transformSchema(schema: StructType): StructType = {
     validateAndTransformSchema(schema, fitting = true, getClassifier.featuresDataType)
   }
 
+  @Since("1.4.0")
   override def fit(dataset: DataFrame): OneVsRestModel = {
     // determine number of classes either from metadata if provided, or via computation.
     val labelSchema = dataset.schema($(labelCol))
@@ -222,6 +235,7 @@ final class OneVsRest(override val uid: String)
     copyValues(model)
   }
 
+  @Since("1.4.1")
   override def copy(extra: ParamMap): OneVsRest = {
     val copied = defaultCopy(extra).asInstanceOf[OneVsRest]
     if (isDefined(classifier)) {
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
@@ -17,7 +17,7 @@
 package org.apache.spark.ml.classification
 
-import org.apache.spark.annotation.Experimental
+import org.apache.spark.annotation.{Experimental, Since}
 import org.apache.spark.ml.tree.impl.RandomForest
 import org.apache.spark.ml.param.ParamMap
 import org.apache.spark.ml.tree.{DecisionTreeModel, RandomForestParams, TreeClassifierParams, TreeEnsembleModel}
@@ -38,44 +38,59 @@ import org.apache.spark.sql.functions._
  * It supports both binary and multiclass labels, as well as both continuous and categorical
  * features.
  */
+@Since("1.4.0")
 @Experimental
-final class RandomForestClassifier(override val uid: String)
+final class RandomForestClassifier @Since("1.4.0") (
+    @Since("1.4.0") override val uid: String)
   extends ProbabilisticClassifier[Vector, RandomForestClassifier, RandomForestClassificationModel]
   with RandomForestParams with TreeClassifierParams {
 
+  @Since("1.4.0")
  def this() = this(Identifiable.randomUID("rfc"))
 
   // Override parameter setters from parent trait for Java API compatibility.
 
   // Parameters from TreeClassifierParams:
+  @Since("1.4.0")
   override def setMaxDepth(value: Int): this.type = super.setMaxDepth(value)
 
+  @Since("1.4.0")
   override def setMaxBins(value: Int): this.type = super.setMaxBins(value)
 
+  @Since("1.4.0")
   override def setMinInstancesPerNode(value: Int): this.type =
     super.setMinInstancesPerNode(value)
 
+  @Since("1.4.0")
   override def setMinInfoGain(value: Double): this.type = super.setMinInfoGain(value)
 
+  @Since("1.4.0")
   override def setMaxMemoryInMB(value: Int): this.type = super.setMaxMemoryInMB(value)
 
+  @Since("1.4.0")
   override def setCacheNodeIds(value: Boolean): this.type = super.setCacheNodeIds(value)
 
+  @Since("1.4.0")
   override def setCheckpointInterval(value: Int): this.type = super.setCheckpointInterval(value)
 
+  @Since("1.4.0")
   override def setImpurity(value: String): this.type = super.setImpurity(value)
 
   // Parameters from TreeEnsembleParams:
 
+  @Since("1.4.0")
   override def setSubsamplingRate(value: Double): this.type = super.setSubsamplingRate(value)
 
+  @Since("1.4.0")
   override def setSeed(value: Long): this.type = super.setSeed(value)
 
   // Parameters from RandomForestParams:
 
+  @Since("1.4.0")
   override def setNumTrees(value: Int): this.type = super.setNumTrees(value)
 
+  @Since("1.4.0")
   override def setFeatureSubsetStrategy(value: String): this.type =
     super.setFeatureSubsetStrategy(value)
@@ -99,15 +114,19 @@ final class RandomForestClassifier(override val uid: String)
     new RandomForestClassificationModel(trees, numFeatures, numClasses)
   }
 
+  @Since("1.4.1")
   override def copy(extra: ParamMap): RandomForestClassifier = defaultCopy(extra)
 }
 
+@Since("1.4.0")
 @Experimental
 object RandomForestClassifier {
   /** Accessor for supported impurity settings: entropy, gini */
+  @Since("1.4.0")
   final val supportedImpurities: Array[String] = TreeClassifierParams.supportedImpurities
 
   /** Accessor for supported featureSubsetStrategy settings: auto, all, onethird, sqrt, log2 */
+  @Since("1.4.0")
   final val supportedFeatureSubsetStrategies: Array[String] =
     RandomForestParams.supportedFeatureSubsetStrategies
 }
@@ -120,12 +139,13 @@ object RandomForestClassifier {
  * @param _trees Decision trees in the ensemble.
  *               Warning: These have null parents.
 */
+@Since("1.4.0")
 @Experimental
 final class RandomForestClassificationModel private[ml] (
-    override val uid: String,
+    @Since("1.5.0") override val uid: String,
     private val _trees: Array[DecisionTreeClassificationModel],
-    override val numFeatures: Int,
-    override val numClasses: Int)
+    @Since("1.6.0") override val numFeatures: Int,
+    @Since("1.5.0") override val numClasses: Int)
   extends ProbabilisticClassificationModel[Vector, RandomForestClassificationModel]
   with TreeEnsembleModel with Serializable {
@@ -141,11 +161,13 @@ final class RandomForestClassificationModel private[ml] (
       numClasses: Int) =
     this(Identifiable.randomUID("rfc"), trees, numFeatures, numClasses)
 
+  @Since("1.4.0")
   override def trees: Array[DecisionTreeModel] = _trees.asInstanceOf[Array[DecisionTreeModel]]
 
   // Note: We may add support for weights (based on tree performance) later on.
   private lazy val _treeWeights: Array[Double] = Array.fill[Double](numTrees)(1.0)
 
+  @Since("1.4.0")
   override def treeWeights: Array[Double] = _treeWeights
 
   override protected def transformImpl(dataset: DataFrame): DataFrame = {
@@ -186,11 +208,13 @@ final class RandomForestClassificationModel private[ml] (
     }
   }
 
+  @Since("1.4.0")
   override def copy(extra: ParamMap): RandomForestClassificationModel = {
     copyValues(new RandomForestClassificationModel(uid, _trees, numFeatures, numClasses), extra)
       .setParent(parent)
   }
 
+  @Since("1.4.0")
   override def toString: String = {
     s"RandomForestClassificationModel (uid=$uid) with $numTrees trees"
   }
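
The annotation is metadata only, so calling code is unchanged by this commit. As a minimal usage sketch of one annotated class (this assumes a Spark 1.6-era project with spark-mllib on the classpath and an existing `training` DataFrame holding "label" and "features" columns; both are assumptions, not part of the commit):

    import org.apache.spark.ml.classification.DecisionTreeClassifier

    // These setters are exactly the members tagged @Since("1.4.0") above.
    val dtc = new DecisionTreeClassifier()
      .setMaxDepth(5)
      .setMaxBins(32)
      .setImpurity("gini")

    val model = dtc.fit(training)  // training: DataFrame with "label" and "features"
    println(model)                 // "DecisionTreeClassificationModel (uid=...) of depth ... with ... nodes"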