Commit 9ace2e5c authored by Yanbo Liang, committed by Xiangrui Meng

[SPARK-11852][ML] StandardScaler minor refactor

```withStd``` and ```withMean``` should be params of ```StandardScaler``` and ```StandardScalerModel```.

Author: Yanbo Liang <ybliang8@gmail.com>

Closes #9839 from yanboliang/standardScaler-refactor.
parent a66142de
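For context, a minimal usage sketch of the refactored API from the caller's side. It assumes a DataFrame `df` with a `Vector` column named `features`, and it uses the estimator's existing `setWithMean`/`setWithStd` setters, which are not part of this diff; after the refactor both flags are ordinary params with defaults declared in `StandardScalerParams`, so they can also be read back from the fitted model.

```scala
import org.apache.spark.ml.feature.StandardScaler

// Sketch only: `df` is an assumed DataFrame with a Vector column "features".
val scaler = new StandardScaler()
  .setInputCol("features")
  .setOutputCol("scaledFeatures")
  .setWithMean(false) // default: false; centering builds dense output, so it fails on sparse input
  .setWithStd(true)   // default: true

val model = scaler.fit(df)

// withMean / withStd are shared params now, readable through the StandardScalerParams getters.
model.getWithMean // false
model.getWithStd  // true

// std and mean are plain constructor vals on the model after this refactor.
println(model.std)
println(model.mean)

val scaled = model.transform(df)
```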
@@ -36,20 +36,30 @@ import org.apache.spark.sql.types.{StructField, StructType}
private[feature] trait StandardScalerParams extends Params with HasInputCol with HasOutputCol {
/**
- * Centers the data with mean before scaling.
+ * Whether to center the data with mean before scaling.
* It will build a dense output, so this does not work on sparse input
* and will raise an exception.
+ * Default: false
* @group param
*/
- val withMean: BooleanParam = new BooleanParam(this, "withMean", "Center data with mean")
+ val withMean: BooleanParam = new BooleanParam(this, "withMean",
+   "Whether to center data with mean")
+ /** @group getParam */
+ def getWithMean: Boolean = $(withMean)
/**
- * Scales the data to unit standard deviation.
+ * Whether to scale the data to unit standard deviation.
+ * Default: true
* @group param
*/
- val withStd: BooleanParam = new BooleanParam(this, "withStd", "Scale to unit standard deviation")
+ val withStd: BooleanParam = new BooleanParam(this, "withStd",
+   "Whether to scale the data to unit standard deviation")
+ /** @group getParam */
+ def getWithStd: Boolean = $(withStd)
+ setDefault(withMean -> false, withStd -> true)
}
/**
@@ -63,8 +73,6 @@ class StandardScaler(override val uid: String) extends Estimator[StandardScalerM
def this() = this(Identifiable.randomUID("stdScal"))
- setDefault(withMean -> false, withStd -> true)
/** @group setParam */
def setInputCol(value: String): this.type = set(inputCol, value)
@@ -82,7 +90,7 @@ class StandardScaler(override val uid: String) extends Estimator[StandardScalerM
val input = dataset.select($(inputCol)).map { case Row(v: Vector) => v }
val scaler = new feature.StandardScaler(withMean = $(withMean), withStd = $(withStd))
val scalerModel = scaler.fit(input)
- copyValues(new StandardScalerModel(uid, scalerModel).setParent(this))
+ copyValues(new StandardScalerModel(uid, scalerModel.std, scalerModel.mean).setParent(this))
}
override def transformSchema(schema: StructType): StructType = {
@@ -108,29 +116,19 @@ object StandardScaler extends DefaultParamsReadable[StandardScaler] {
/**
* :: Experimental ::
* Model fitted by [[StandardScaler]].
+ *
+ * @param std Standard deviation of the StandardScalerModel
+ * @param mean Mean of the StandardScalerModel
*/
@Experimental
class StandardScalerModel private[ml] (
override val uid: String,
- scaler: feature.StandardScalerModel)
+ val std: Vector,
+ val mean: Vector)
extends Model[StandardScalerModel] with StandardScalerParams with MLWritable {
import StandardScalerModel._
- /** Standard deviation of the StandardScalerModel */
- val std: Vector = scaler.std
- /** Mean of the StandardScalerModel */
- val mean: Vector = scaler.mean
- /** Whether to scale to unit standard deviation. */
- @Since("1.6.0")
- def getWithStd: Boolean = scaler.withStd
- /** Whether to center data with mean. */
- @Since("1.6.0")
- def getWithMean: Boolean = scaler.withMean
/** @group setParam */
def setInputCol(value: String): this.type = set(inputCol, value)
@@ -139,6 +137,7 @@ class StandardScalerModel private[ml] (
override def transform(dataset: DataFrame): DataFrame = {
transformSchema(dataset.schema, logging = true)
+ val scaler = new feature.StandardScalerModel(std, mean, $(withStd), $(withMean))
val scale = udf { scaler.transform _ }
dataset.withColumn($(outputCol), scale(col($(inputCol))))
}
@@ -154,7 +153,7 @@ class StandardScalerModel private[ml] (
}
override def copy(extra: ParamMap): StandardScalerModel = {
- val copied = new StandardScalerModel(uid, scaler)
+ val copied = new StandardScalerModel(uid, std, mean)
copyValues(copied, extra).setParent(parent)
}
@@ -168,11 +167,11 @@ object StandardScalerModel extends MLReadable[StandardScalerModel] {
private[StandardScalerModel]
class StandardScalerModelWriter(instance: StandardScalerModel) extends MLWriter {
- private case class Data(std: Vector, mean: Vector, withStd: Boolean, withMean: Boolean)
+ private case class Data(std: Vector, mean: Vector)
override protected def saveImpl(path: String): Unit = {
DefaultParamsWriter.saveMetadata(instance, path, sc)
- val data = Data(instance.std, instance.mean, instance.getWithStd, instance.getWithMean)
+ val data = Data(instance.std, instance.mean)
val dataPath = new Path(path, "data").toString
sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
}
@@ -185,13 +184,10 @@ object StandardScalerModel extends MLReadable[StandardScalerModel] {
override def load(path: String): StandardScalerModel = {
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
val dataPath = new Path(path, "data").toString
- val Row(std: Vector, mean: Vector, withStd: Boolean, withMean: Boolean) =
-   sqlContext.read.parquet(dataPath)
-     .select("std", "mean", "withStd", "withMean")
-     .head()
- // This is very likely to change in the future because withStd and withMean should be params.
- val oldModel = new feature.StandardScalerModel(std, mean, withStd, withMean)
- val model = new StandardScalerModel(metadata.uid, oldModel)
+ val Row(std: Vector, mean: Vector) = sqlContext.read.parquet(dataPath)
+   .select("std", "mean")
+   .head()
+ val model = new StandardScalerModel(metadata.uid, std, mean)
DefaultParamsReader.getAndSetParams(model, metadata)
model
}
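With this change the persisted `Data` row carries only `std` and `mean`, while `withStd` and `withMean` travel with the param metadata written by `DefaultParamsWriter` and are restored via `DefaultParamsReader.getAndSetParams`. A hedged sketch of the resulting round trip, reusing the fitted `model` from the sketch above and a hypothetical save path:

```scala
import org.apache.spark.ml.feature.StandardScalerModel

// `model` is the fitted StandardScalerModel from the earlier sketch; the path is hypothetical.
model.write.overwrite().save("/tmp/standard-scaler-model")

val restored = StandardScalerModel.load("/tmp/standard-scaler-model")

// std / mean come back from the data Parquet file; withStd / withMean from the param metadata.
assert(restored.std == model.std && restored.mean == model.mean)
assert(restored.getWithStd == model.getWithStd)
assert(restored.getWithMean == model.getWithMean)
```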
@@ -70,8 +70,8 @@ class StandardScalerSuite extends SparkFunSuite with MLlibTestSparkContext
test("params") {
ParamsSuite.checkParams(new StandardScaler)
- val oldModel = new feature.StandardScalerModel(Vectors.dense(1.0), Vectors.dense(2.0))
- ParamsSuite.checkParams(new StandardScalerModel("empty", oldModel))
+ ParamsSuite.checkParams(new StandardScalerModel("empty",
+   Vectors.dense(1.0), Vectors.dense(2.0)))
}
test("Standardization with default parameter") {
@@ -126,13 +126,10 @@ class StandardScalerSuite extends SparkFunSuite with MLlibTestSparkContext
}
test("StandardScalerModel read/write") {
- val oldModel = new feature.StandardScalerModel(
-   Vectors.dense(1.0, 2.0), Vectors.dense(3.0, 4.0), false, true)
- val instance = new StandardScalerModel("myStandardScalerModel", oldModel)
+ val instance = new StandardScalerModel("myStandardScalerModel",
+   Vectors.dense(1.0, 2.0), Vectors.dense(3.0, 4.0))
val newInstance = testDefaultReadWrite(instance)
assert(newInstance.std === instance.std)
assert(newInstance.mean === instance.mean)
- assert(newInstance.getWithStd === instance.getWithStd)
- assert(newInstance.getWithMean === instance.getWithMean)
}
}
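For reference, a hedged sketch of the element-wise standardization a fitted model applies. The helper below is purely illustrative (it is not Spark code), covers only the dense case, and simplifies the handling of zero-variance features:

```scala
// Illustrative only: (x - mean) / std, gated by withMean / withStd.
def standardize(
    x: Array[Double],
    mean: Array[Double],
    std: Array[Double],
    withMean: Boolean,
    withStd: Boolean): Array[Double] = {
  x.indices.map { i =>
    val centered = if (withMean) x(i) - mean(i) else x(i)
    if (withStd && std(i) != 0.0) centered / std(i) else centered
  }.toArray
}
```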