Skip to content
Snippets Groups Projects
Commit e3727c40 authored by Takahashi Hiroshi's avatar Takahashi Hiroshi Committed by Xiangrui Meng
Browse files

[SPARK-10263][ML] Add @Since annotation to ml.param and ml.*

Add Since annotations to ml.param and ml.*

Author: Takahashi Hiroshi <takahashi.hiroshi@lab.ntt.co.jp>
Author: Hiroshi Takahashi <takahashi.hiroshi@lab.ntt.co.jp>

Closes #8935 from taishi-oss/issue10263.
parent ab4a6bfd
No related branches found
No related tags found
No related merge requests found
...@@ -85,25 +85,32 @@ abstract class PipelineStage extends Params with Logging { ...@@ -85,25 +85,32 @@ abstract class PipelineStage extends Params with Logging {
* transformers, corresponding to the pipeline stages. If there are no stages, the pipeline acts as * transformers, corresponding to the pipeline stages. If there are no stages, the pipeline acts as
* an identity transformer. * an identity transformer.
*/ */
@Since("1.2.0")
@Experimental @Experimental
class Pipeline(override val uid: String) extends Estimator[PipelineModel] with MLWritable { class Pipeline @Since("1.4.0") (
@Since("1.4.0") override val uid: String) extends Estimator[PipelineModel] with MLWritable {
@Since("1.4.0")
def this() = this(Identifiable.randomUID("pipeline")) def this() = this(Identifiable.randomUID("pipeline"))
/** /**
* param for pipeline stages * param for pipeline stages
* @group param * @group param
*/ */
@Since("1.2.0")
val stages: Param[Array[PipelineStage]] = new Param(this, "stages", "stages of the pipeline") val stages: Param[Array[PipelineStage]] = new Param(this, "stages", "stages of the pipeline")
/** @group setParam */ /** @group setParam */
@Since("1.2.0")
def setStages(value: Array[PipelineStage]): this.type = { set(stages, value); this } def setStages(value: Array[PipelineStage]): this.type = { set(stages, value); this }
// Below, we clone stages so that modifications to the list of stages will not change // Below, we clone stages so that modifications to the list of stages will not change
// the Param value in the Pipeline. // the Param value in the Pipeline.
/** @group getParam */ /** @group getParam */
@Since("1.2.0")
def getStages: Array[PipelineStage] = $(stages).clone() def getStages: Array[PipelineStage] = $(stages).clone()
@Since("1.4.0")
override def validateParams(): Unit = { override def validateParams(): Unit = {
super.validateParams() super.validateParams()
$(stages).foreach(_.validateParams()) $(stages).foreach(_.validateParams())
...@@ -121,6 +128,7 @@ class Pipeline(override val uid: String) extends Estimator[PipelineModel] with M ...@@ -121,6 +128,7 @@ class Pipeline(override val uid: String) extends Estimator[PipelineModel] with M
* @param dataset input dataset * @param dataset input dataset
* @return fitted pipeline * @return fitted pipeline
*/ */
@Since("1.2.0")
override def fit(dataset: DataFrame): PipelineModel = { override def fit(dataset: DataFrame): PipelineModel = {
transformSchema(dataset.schema, logging = true) transformSchema(dataset.schema, logging = true)
val theStages = $(stages) val theStages = $(stages)
...@@ -158,12 +166,14 @@ class Pipeline(override val uid: String) extends Estimator[PipelineModel] with M ...@@ -158,12 +166,14 @@ class Pipeline(override val uid: String) extends Estimator[PipelineModel] with M
new PipelineModel(uid, transformers.toArray).setParent(this) new PipelineModel(uid, transformers.toArray).setParent(this)
} }
@Since("1.4.0")
override def copy(extra: ParamMap): Pipeline = { override def copy(extra: ParamMap): Pipeline = {
val map = extractParamMap(extra) val map = extractParamMap(extra)
val newStages = map(stages).map(_.copy(extra)) val newStages = map(stages).map(_.copy(extra))
new Pipeline().setStages(newStages) new Pipeline().setStages(newStages)
} }
@Since("1.2.0")
override def transformSchema(schema: StructType): StructType = { override def transformSchema(schema: StructType): StructType = {
validateParams() validateParams()
val theStages = $(stages) val theStages = $(stages)
...@@ -275,10 +285,11 @@ object Pipeline extends MLReadable[Pipeline] { ...@@ -275,10 +285,11 @@ object Pipeline extends MLReadable[Pipeline] {
* :: Experimental :: * :: Experimental ::
* Represents a fitted pipeline. * Represents a fitted pipeline.
*/ */
@Since("1.2.0")
@Experimental @Experimental
class PipelineModel private[ml] ( class PipelineModel private[ml] (
override val uid: String, @Since("1.4.0") override val uid: String,
val stages: Array[Transformer]) @Since("1.4.0") val stages: Array[Transformer])
extends Model[PipelineModel] with MLWritable with Logging { extends Model[PipelineModel] with MLWritable with Logging {
/** A Java/Python-friendly auxiliary constructor. */ /** A Java/Python-friendly auxiliary constructor. */
...@@ -286,21 +297,25 @@ class PipelineModel private[ml] ( ...@@ -286,21 +297,25 @@ class PipelineModel private[ml] (
this(uid, stages.asScala.toArray) this(uid, stages.asScala.toArray)
} }
@Since("1.4.0")
override def validateParams(): Unit = { override def validateParams(): Unit = {
super.validateParams() super.validateParams()
stages.foreach(_.validateParams()) stages.foreach(_.validateParams())
} }
@Since("1.2.0")
override def transform(dataset: DataFrame): DataFrame = { override def transform(dataset: DataFrame): DataFrame = {
transformSchema(dataset.schema, logging = true) transformSchema(dataset.schema, logging = true)
stages.foldLeft(dataset)((cur, transformer) => transformer.transform(cur)) stages.foldLeft(dataset)((cur, transformer) => transformer.transform(cur))
} }
@Since("1.2.0")
override def transformSchema(schema: StructType): StructType = { override def transformSchema(schema: StructType): StructType = {
validateParams() validateParams()
stages.foldLeft(schema)((cur, transformer) => transformer.transformSchema(cur)) stages.foldLeft(schema)((cur, transformer) => transformer.transformSchema(cur))
} }
@Since("1.4.0")
override def copy(extra: ParamMap): PipelineModel = { override def copy(extra: ParamMap): PipelineModel = {
new PipelineModel(uid, stages.map(_.copy(extra))).setParent(parent) new PipelineModel(uid, stages.map(_.copy(extra))).setParent(parent)
} }
......
...@@ -27,7 +27,7 @@ import scala.collection.JavaConverters._ ...@@ -27,7 +27,7 @@ import scala.collection.JavaConverters._
import org.json4s._ import org.json4s._
import org.json4s.jackson.JsonMethods._ import org.json4s.jackson.JsonMethods._
import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.annotation.{DeveloperApi, Experimental, Since}
import org.apache.spark.ml.util.Identifiable import org.apache.spark.ml.util.Identifiable
import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.linalg.{Vector, Vectors}
...@@ -504,8 +504,11 @@ class IntArrayParam(parent: Params, name: String, doc: String, isValid: Array[In ...@@ -504,8 +504,11 @@ class IntArrayParam(parent: Params, name: String, doc: String, isValid: Array[In
* :: Experimental :: * :: Experimental ::
* A param and its value. * A param and its value.
*/ */
@Since("1.2.0")
@Experimental @Experimental
case class ParamPair[T](param: Param[T], value: T) { case class ParamPair[T] @Since("1.2.0") (
@Since("1.2.0") param: Param[T],
@Since("1.2.0") value: T) {
// This is *the* place Param.validate is called. Whenever a parameter is specified, we should // This is *the* place Param.validate is called. Whenever a parameter is specified, we should
// always construct a ParamPair so that validate is called. // always construct a ParamPair so that validate is called.
param.validate(value) param.validate(value)
...@@ -786,6 +789,7 @@ abstract class JavaParams extends Params ...@@ -786,6 +789,7 @@ abstract class JavaParams extends Params
* :: Experimental :: * :: Experimental ::
* A param to value map. * A param to value map.
*/ */
@Since("1.2.0")
@Experimental @Experimental
final class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) final class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any])
extends Serializable { extends Serializable {
...@@ -799,17 +803,20 @@ final class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) ...@@ -799,17 +803,20 @@ final class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any])
/** /**
* Creates an empty param map. * Creates an empty param map.
*/ */
@Since("1.2.0")
def this() = this(mutable.Map.empty) def this() = this(mutable.Map.empty)
/** /**
* Puts a (param, value) pair (overwrites if the input param exists). * Puts a (param, value) pair (overwrites if the input param exists).
*/ */
@Since("1.2.0")
def put[T](param: Param[T], value: T): this.type = put(param -> value) def put[T](param: Param[T], value: T): this.type = put(param -> value)
/** /**
* Puts a list of param pairs (overwrites if the input params exists). * Puts a list of param pairs (overwrites if the input params exists).
*/ */
@varargs @varargs
@Since("1.2.0")
def put(paramPairs: ParamPair[_]*): this.type = { def put(paramPairs: ParamPair[_]*): this.type = {
paramPairs.foreach { p => paramPairs.foreach { p =>
map(p.param.asInstanceOf[Param[Any]]) = p.value map(p.param.asInstanceOf[Param[Any]]) = p.value
...@@ -820,6 +827,7 @@ final class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) ...@@ -820,6 +827,7 @@ final class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any])
/** /**
* Optionally returns the value associated with a param. * Optionally returns the value associated with a param.
*/ */
@Since("1.2.0")
def get[T](param: Param[T]): Option[T] = { def get[T](param: Param[T]): Option[T] = {
map.get(param.asInstanceOf[Param[Any]]).asInstanceOf[Option[T]] map.get(param.asInstanceOf[Param[Any]]).asInstanceOf[Option[T]]
} }
...@@ -827,6 +835,7 @@ final class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) ...@@ -827,6 +835,7 @@ final class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any])
/** /**
* Returns the value associated with a param or a default value. * Returns the value associated with a param or a default value.
*/ */
@Since("1.4.0")
def getOrElse[T](param: Param[T], default: T): T = { def getOrElse[T](param: Param[T], default: T): T = {
get(param).getOrElse(default) get(param).getOrElse(default)
} }
...@@ -835,6 +844,7 @@ final class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) ...@@ -835,6 +844,7 @@ final class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any])
* Gets the value of the input param or its default value if it does not exist. * Gets the value of the input param or its default value if it does not exist.
* Raises a NoSuchElementException if there is no value associated with the input param. * Raises a NoSuchElementException if there is no value associated with the input param.
*/ */
@Since("1.2.0")
def apply[T](param: Param[T]): T = { def apply[T](param: Param[T]): T = {
get(param).getOrElse { get(param).getOrElse {
throw new NoSuchElementException(s"Cannot find param ${param.name}.") throw new NoSuchElementException(s"Cannot find param ${param.name}.")
...@@ -844,6 +854,7 @@ final class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) ...@@ -844,6 +854,7 @@ final class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any])
/** /**
* Checks whether a parameter is explicitly specified. * Checks whether a parameter is explicitly specified.
*/ */
@Since("1.2.0")
def contains(param: Param[_]): Boolean = { def contains(param: Param[_]): Boolean = {
map.contains(param.asInstanceOf[Param[Any]]) map.contains(param.asInstanceOf[Param[Any]])
} }
...@@ -851,6 +862,7 @@ final class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) ...@@ -851,6 +862,7 @@ final class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any])
/** /**
* Removes a key from this map and returns its value associated previously as an option. * Removes a key from this map and returns its value associated previously as an option.
*/ */
@Since("1.4.0")
def remove[T](param: Param[T]): Option[T] = { def remove[T](param: Param[T]): Option[T] = {
map.remove(param.asInstanceOf[Param[Any]]).asInstanceOf[Option[T]] map.remove(param.asInstanceOf[Param[Any]]).asInstanceOf[Option[T]]
} }
...@@ -858,6 +870,7 @@ final class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) ...@@ -858,6 +870,7 @@ final class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any])
/** /**
* Filters this param map for the given parent. * Filters this param map for the given parent.
*/ */
@Since("1.2.0")
def filter(parent: Params): ParamMap = { def filter(parent: Params): ParamMap = {
// Don't use filterKeys because mutable.Map#filterKeys // Don't use filterKeys because mutable.Map#filterKeys
// returns the instance of collections.Map, not mutable.Map. // returns the instance of collections.Map, not mutable.Map.
...@@ -870,8 +883,10 @@ final class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) ...@@ -870,8 +883,10 @@ final class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any])
/** /**
* Creates a copy of this param map. * Creates a copy of this param map.
*/ */
@Since("1.2.0")
def copy: ParamMap = new ParamMap(map.clone()) def copy: ParamMap = new ParamMap(map.clone())
@Since("1.2.0")
override def toString: String = { override def toString: String = {
map.toSeq.sortBy(_._1.name).map { case (param, value) => map.toSeq.sortBy(_._1.name).map { case (param, value) =>
s"\t${param.parent}-${param.name}: $value" s"\t${param.parent}-${param.name}: $value"
...@@ -882,6 +897,7 @@ final class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) ...@@ -882,6 +897,7 @@ final class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any])
* Returns a new param map that contains parameters in this map and the given map, * Returns a new param map that contains parameters in this map and the given map,
* where the latter overwrites this if there exist conflicts. * where the latter overwrites this if there exist conflicts.
*/ */
@Since("1.2.0")
def ++(other: ParamMap): ParamMap = { def ++(other: ParamMap): ParamMap = {
// TODO: Provide a better method name for Java users. // TODO: Provide a better method name for Java users.
new ParamMap(this.map ++ other.map) new ParamMap(this.map ++ other.map)
...@@ -890,6 +906,7 @@ final class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) ...@@ -890,6 +906,7 @@ final class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any])
/** /**
* Adds all parameters from the input param map into this param map. * Adds all parameters from the input param map into this param map.
*/ */
@Since("1.2.0")
def ++=(other: ParamMap): this.type = { def ++=(other: ParamMap): this.type = {
// TODO: Provide a better method name for Java users. // TODO: Provide a better method name for Java users.
this.map ++= other.map this.map ++= other.map
...@@ -899,6 +916,7 @@ final class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) ...@@ -899,6 +916,7 @@ final class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any])
/** /**
* Converts this param map to a sequence of param pairs. * Converts this param map to a sequence of param pairs.
*/ */
@Since("1.2.0")
def toSeq: Seq[ParamPair[_]] = { def toSeq: Seq[ParamPair[_]] = {
map.toSeq.map { case (param, value) => map.toSeq.map { case (param, value) =>
ParamPair(param, value) ParamPair(param, value)
...@@ -908,21 +926,25 @@ final class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) ...@@ -908,21 +926,25 @@ final class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any])
/** /**
* Number of param pairs in this map. * Number of param pairs in this map.
*/ */
@Since("1.3.0")
def size: Int = map.size def size: Int = map.size
} }
@Since("1.2.0")
@Experimental @Experimental
object ParamMap { object ParamMap {
/** /**
* Returns an empty param map. * Returns an empty param map.
*/ */
@Since("1.2.0")
def empty: ParamMap = new ParamMap() def empty: ParamMap = new ParamMap()
/** /**
* Constructs a param map by specifying its entries. * Constructs a param map by specifying its entries.
*/ */
@varargs @varargs
@Since("1.2.0")
def apply(paramPairs: ParamPair[_]*): ParamMap = { def apply(paramPairs: ParamPair[_]*): ParamMap = {
new ParamMap().put(paramPairs: _*) new ParamMap().put(paramPairs: _*)
} }
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment