Skip to content
Snippets Groups Projects
Commit 4b70798c authored by Yanbo Liang's avatar Yanbo Liang Committed by Xiangrui Meng
Browse files

[MINOR] [ML] change MultilayerPerceptronClassifierModel to MultilayerPerceptronClassificationModel

To follow the naming rule of ML, change `MultilayerPerceptronClassifierModel` to `MultilayerPerceptronClassificationModel` like `DecisionTreeClassificationModel`, `GBTClassificationModel` and so on.

Author: Yanbo Liang <ybliang8@gmail.com>

Closes #8164 from yanboliang/mlp-name.
parent 7a539ef3
No related branches found
No related tags found
No related merge requests found
...@@ -131,7 +131,7 @@ private object LabelConverter { ...@@ -131,7 +131,7 @@ private object LabelConverter {
*/ */
@Experimental @Experimental
class MultilayerPerceptronClassifier(override val uid: String) class MultilayerPerceptronClassifier(override val uid: String)
extends Predictor[Vector, MultilayerPerceptronClassifier, MultilayerPerceptronClassifierModel] extends Predictor[Vector, MultilayerPerceptronClassifier, MultilayerPerceptronClassificationModel]
with MultilayerPerceptronParams { with MultilayerPerceptronParams {
def this() = this(Identifiable.randomUID("mlpc")) def this() = this(Identifiable.randomUID("mlpc"))
...@@ -146,7 +146,7 @@ class MultilayerPerceptronClassifier(override val uid: String) ...@@ -146,7 +146,7 @@ class MultilayerPerceptronClassifier(override val uid: String)
* @param dataset Training dataset * @param dataset Training dataset
* @return Fitted model * @return Fitted model
*/ */
override protected def train(dataset: DataFrame): MultilayerPerceptronClassifierModel = { override protected def train(dataset: DataFrame): MultilayerPerceptronClassificationModel = {
val myLayers = $(layers) val myLayers = $(layers)
val labels = myLayers.last val labels = myLayers.last
val lpData = extractLabeledPoints(dataset) val lpData = extractLabeledPoints(dataset)
...@@ -156,13 +156,13 @@ class MultilayerPerceptronClassifier(override val uid: String) ...@@ -156,13 +156,13 @@ class MultilayerPerceptronClassifier(override val uid: String)
FeedForwardTrainer.LBFGSOptimizer.setConvergenceTol($(tol)).setNumIterations($(maxIter)) FeedForwardTrainer.LBFGSOptimizer.setConvergenceTol($(tol)).setNumIterations($(maxIter))
FeedForwardTrainer.setStackSize($(blockSize)) FeedForwardTrainer.setStackSize($(blockSize))
val mlpModel = FeedForwardTrainer.train(data) val mlpModel = FeedForwardTrainer.train(data)
new MultilayerPerceptronClassifierModel(uid, myLayers, mlpModel.weights()) new MultilayerPerceptronClassificationModel(uid, myLayers, mlpModel.weights())
} }
} }
/** /**
* :: Experimental :: * :: Experimental ::
* Classifier model based on the Multilayer Perceptron. * Classification model based on the Multilayer Perceptron.
* Each layer has sigmoid activation function, output layer has softmax. * Each layer has sigmoid activation function, output layer has softmax.
* @param uid uid * @param uid uid
* @param layers array of layer sizes including input and output layers * @param layers array of layer sizes including input and output layers
...@@ -170,11 +170,11 @@ class MultilayerPerceptronClassifier(override val uid: String) ...@@ -170,11 +170,11 @@ class MultilayerPerceptronClassifier(override val uid: String)
* @return prediction model * @return prediction model
*/ */
@Experimental @Experimental
class MultilayerPerceptronClassifierModel private[ml] ( class MultilayerPerceptronClassificationModel private[ml] (
override val uid: String, override val uid: String,
layers: Array[Int], layers: Array[Int],
weights: Vector) weights: Vector)
extends PredictionModel[Vector, MultilayerPerceptronClassifierModel] extends PredictionModel[Vector, MultilayerPerceptronClassificationModel]
with Serializable { with Serializable {
private val mlpModel = FeedForwardTopology.multiLayerPerceptron(layers, true).getInstance(weights) private val mlpModel = FeedForwardTopology.multiLayerPerceptron(layers, true).getInstance(weights)
...@@ -187,7 +187,7 @@ class MultilayerPerceptronClassifierModel private[ml] ( ...@@ -187,7 +187,7 @@ class MultilayerPerceptronClassifierModel private[ml] (
LabelConverter.decodeLabel(mlpModel.predict(features)) LabelConverter.decodeLabel(mlpModel.predict(features))
} }
override def copy(extra: ParamMap): MultilayerPerceptronClassifierModel = { override def copy(extra: ParamMap): MultilayerPerceptronClassificationModel = {
copyValues(new MultilayerPerceptronClassifierModel(uid, layers, weights), extra) copyValues(new MultilayerPerceptronClassificationModel(uid, layers, weights), extra)
} }
} }
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment