Skip to content
Snippets Groups Projects
Commit 6c5a837c authored by Joseph K. Bradley's avatar Joseph K. Bradley
Browse files

[SPARK-12301][ML] Made all tree and ensemble classes not final

## What changes were proposed in this pull request?

There have been continuing requests (e.g., SPARK-7131) for allowing users to extend and modify MLlib models and algorithms.

This PR makes tree and ensemble classes, Node types, and Split types in spark.ml no longer final.  This matches most other spark.ml algorithms.

Constructors for models are still private since we may need to refactor how stats are maintained in tree nodes.

## How was this patch tested?

Existing unit tests

Author: Joseph K. Bradley <joseph@databricks.com>

Closes #12711 from jkbradley/final-trees.
parent e88476c8
No related branches found
No related tags found
No related merge requests found
......@@ -44,7 +44,7 @@ import org.apache.spark.sql.Dataset
*/
@Since("1.4.0")
@Experimental
final class DecisionTreeClassifier @Since("1.4.0") (
class DecisionTreeClassifier @Since("1.4.0") (
@Since("1.4.0") override val uid: String)
extends ProbabilisticClassifier[Vector, DecisionTreeClassifier, DecisionTreeClassificationModel]
with DecisionTreeClassifierParams with DefaultParamsWritable {
......@@ -138,7 +138,7 @@ object DecisionTreeClassifier extends DefaultParamsReadable[DecisionTreeClassifi
*/
@Since("1.4.0")
@Experimental
final class DecisionTreeClassificationModel private[ml] (
class DecisionTreeClassificationModel private[ml] (
@Since("1.4.0")override val uid: String,
@Since("1.4.0")override val rootNode: Node,
@Since("1.6.0")override val numFeatures: Int,
......
......@@ -57,7 +57,7 @@ import org.apache.spark.sql.functions._
*/
@Since("1.4.0")
@Experimental
final class GBTClassifier @Since("1.4.0") (
class GBTClassifier @Since("1.4.0") (
@Since("1.4.0") override val uid: String)
extends Predictor[Vector, GBTClassifier, GBTClassificationModel]
with GBTClassifierParams with DefaultParamsWritable with Logging {
......@@ -170,7 +170,7 @@ object GBTClassifier extends DefaultParamsReadable[GBTClassifier] {
*/
@Since("1.6.0")
@Experimental
final class GBTClassificationModel private[ml](
class GBTClassificationModel private[ml](
@Since("1.6.0") override val uid: String,
private val _trees: Array[DecisionTreeRegressionModel],
private val _treeWeights: Array[Double],
......
......@@ -44,7 +44,7 @@ import org.apache.spark.sql.functions._
*/
@Since("1.4.0")
@Experimental
final class RandomForestClassifier @Since("1.4.0") (
class RandomForestClassifier @Since("1.4.0") (
@Since("1.4.0") override val uid: String)
extends ProbabilisticClassifier[Vector, RandomForestClassifier, RandomForestClassificationModel]
with RandomForestClassifierParams with DefaultParamsWritable {
......@@ -149,7 +149,7 @@ object RandomForestClassifier extends DefaultParamsReadable[RandomForestClassifi
*/
@Since("1.4.0")
@Experimental
final class RandomForestClassificationModel private[ml] (
class RandomForestClassificationModel private[ml] (
@Since("1.5.0") override val uid: String,
private val _trees: Array[DecisionTreeClassificationModel],
@Since("1.6.0") override val numFeatures: Int,
......
......@@ -45,7 +45,7 @@ import org.apache.spark.sql.functions._
*/
@Since("1.4.0")
@Experimental
final class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String)
class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String)
extends Predictor[Vector, DecisionTreeRegressor, DecisionTreeRegressionModel]
with DecisionTreeRegressorParams with DefaultParamsWritable {
......@@ -129,7 +129,7 @@ object DecisionTreeRegressor extends DefaultParamsReadable[DecisionTreeRegressor
*/
@Since("1.4.0")
@Experimental
final class DecisionTreeRegressionModel private[ml] (
class DecisionTreeRegressionModel private[ml] (
override val uid: String,
override val rootNode: Node,
override val numFeatures: Int)
......
......@@ -57,7 +57,7 @@ import org.apache.spark.sql.functions._
*/
@Since("1.4.0")
@Experimental
final class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String)
class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String)
extends Predictor[Vector, GBTRegressor, GBTRegressionModel]
with GBTRegressorParams with DefaultParamsWritable with Logging {
......@@ -157,7 +157,7 @@ object GBTRegressor extends DefaultParamsReadable[GBTRegressor] {
*/
@Since("1.4.0")
@Experimental
final class GBTRegressionModel private[ml](
class GBTRegressionModel private[ml](
override val uid: String,
private val _trees: Array[DecisionTreeRegressionModel],
private val _treeWeights: Array[Double],
......
......@@ -43,7 +43,7 @@ import org.apache.spark.sql.functions._
*/
@Since("1.4.0")
@Experimental
final class RandomForestRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String)
class RandomForestRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String)
extends Predictor[Vector, RandomForestRegressor, RandomForestRegressionModel]
with RandomForestRegressorParams with DefaultParamsWritable {
......@@ -137,7 +137,7 @@ object RandomForestRegressor extends DefaultParamsReadable[RandomForestRegressor
*/
@Since("1.4.0")
@Experimental
final class RandomForestRegressionModel private[ml] (
class RandomForestRegressionModel private[ml] (
override val uid: String,
private val _trees: Array[DecisionTreeRegressionModel],
override val numFeatures: Int)
......
......@@ -115,7 +115,7 @@ private[ml] object Node {
* @param impurity Impurity measure at this node (for training data)
*/
@DeveloperApi
final class LeafNode private[ml] (
class LeafNode private[ml] (
override val prediction: Double,
override val impurity: Double,
override private[ml] val impurityStats: ImpurityCalculator) extends Node {
......@@ -158,7 +158,7 @@ final class LeafNode private[ml] (
* @param split Information about the test used to split to the left or right child.
*/
@DeveloperApi
final class InternalNode private[ml] (
class InternalNode private[ml] (
override val prediction: Double,
override val impurity: Double,
val gain: Double,
......
......@@ -75,7 +75,7 @@ private[tree] object Split {
* @param numCategories Number of categories for this feature.
*/
@DeveloperApi
final class CategoricalSplit private[ml] (
class CategoricalSplit private[ml] (
override val featureIndex: Int,
_leftCategories: Array[Double],
@Since("2.0.0") val numCategories: Int)
......@@ -160,7 +160,7 @@ final class CategoricalSplit private[ml] (
* Otherwise, it goes right.
*/
@DeveloperApi
final class ContinuousSplit private[ml] (override val featureIndex: Int, val threshold: Double)
class ContinuousSplit private[ml] (override val featureIndex: Int, val threshold: Double)
extends Split {
override private[ml] def shouldGoLeft(features: Vector): Boolean = {
......
0% Loading…
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment