From 5855e0057defeab8006ca4f7b0196003bbc9e899 Mon Sep 17 00:00:00 2001
From: Yuhao Yang <yuhao.yang@intel.com>
Date: Thu, 2 Jun 2016 16:37:01 -0700
Subject: [PATCH] [SPARK-15668][ML] ml.feature: update schema checks to avoid
 confusion when users use an MLlib Vector as the input type

## What changes were proposed in this pull request?

Update the schema checks in ml.feature to avoid confusing error messages when users pass an old MLlib Vector as the input type: the duplicated inline type checks are replaced with SchemaUtils.checkColumnType, and the validation logic is consolidated into a shared validateAndTransformSchema.

## How was this patch tested?

Existing unit tests.

Author: Yuhao Yang <yuhao.yang@intel.com>

Closes #13411 from hhbyyh/schemaCheck.
---
 .../spark/ml/feature/MaxAbsScaler.scala      |  4 +--
 .../spark/ml/feature/MinMaxScaler.scala      |  4 +--
 .../org/apache/spark/ml/feature/PCA.scala    | 28 ++++++++-----------
 .../spark/ml/feature/StandardScaler.scala    | 25 ++++++++---------
 4 files changed, 25 insertions(+), 36 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala
index 1b5159902e..7298a18ff8 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala
@@ -39,9 +39,7 @@ private[feature] trait MaxAbsScalerParams extends Params with HasInputCol with H
 
   /** Validates and transforms the input schema. */
   protected def validateAndTransformSchema(schema: StructType): StructType = {
-    val inputType = schema($(inputCol)).dataType
-    require(inputType.isInstanceOf[VectorUDT],
-      s"Input column ${$(inputCol)} must be a vector column")
+    SchemaUtils.checkColumnType(schema, $(inputCol), new VectorUDT)
     require(!schema.fieldNames.contains($(outputCol)),
       s"Output column ${$(outputCol)} already exists.")
     val outputFields = schema.fields :+ StructField($(outputCol), new VectorUDT, false)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala
index d15f1b8563..a27bed5333 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala
@@ -63,9 +63,7 @@ private[feature] trait MinMaxScalerParams extends Params with HasInputCol with H
   /** Validates and transforms the input schema. */
   protected def validateAndTransformSchema(schema: StructType): StructType = {
     require($(min) < $(max), s"The specified min(${$(min)}) is larger or equal to max(${$(max)})")
-    val inputType = schema($(inputCol)).dataType
-    require(inputType.isInstanceOf[VectorUDT],
-      s"Input column ${$(inputCol)} must be a vector column")
+    SchemaUtils.checkColumnType(schema, $(inputCol), new VectorUDT)
     require(!schema.fieldNames.contains($(outputCol)),
       s"Output column ${$(outputCol)} already exists.")
     val outputFields = schema.fields :+ StructField($(outputCol), new VectorUDT, false)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala
index dbbaa5aa46..2f667af9d1 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala
@@ -49,8 +49,16 @@ private[feature] trait PCAParams extends Params with HasInputCol with HasOutputC
 
   /** @group getParam */
   def getK: Int = $(k)
-}
 
+  /** Validates and transforms the input schema. */
+  protected def validateAndTransformSchema(schema: StructType): StructType = {
+    SchemaUtils.checkColumnType(schema, $(inputCol), new VectorUDT)
+    require(!schema.fieldNames.contains($(outputCol)),
+      s"Output column ${$(outputCol)} already exists.")
+    val outputFields = schema.fields :+ StructField($(outputCol), new VectorUDT, false)
+    StructType(outputFields)
+  }
+}
 /**
  * :: Experimental ::
  * PCA trains a model to project vectors to a lower dimensional space of the top [[PCA!.k]]
@@ -86,13 +94,7 @@ class PCA (override val uid: String) extends Estimator[PCAModel] with PCAParams
   }
 
   override def transformSchema(schema: StructType): StructType = {
-    val inputType = schema($(inputCol)).dataType
-    require(inputType.isInstanceOf[VectorUDT],
-      s"Input column ${$(inputCol)} must be a vector column")
-    require(!schema.fieldNames.contains($(outputCol)),
-      s"Output column ${$(outputCol)} already exists.")
-    val outputFields = schema.fields :+ StructField($(outputCol), new VectorUDT, false)
-    StructType(outputFields)
+    validateAndTransformSchema(schema)
   }
 
   override def copy(extra: ParamMap): PCA = defaultCopy(extra)
@@ -148,13 +150,7 @@ class PCAModel private[ml] (
   }
 
   override def transformSchema(schema: StructType): StructType = {
-    val inputType = schema($(inputCol)).dataType
-    require(inputType.isInstanceOf[VectorUDT],
-      s"Input column ${$(inputCol)} must be a vector column")
-    require(!schema.fieldNames.contains($(outputCol)),
-      s"Output column ${$(outputCol)} already exists.")
-    val outputFields = schema.fields :+ StructField($(outputCol), new VectorUDT, false)
-    StructType(outputFields)
+    validateAndTransformSchema(schema)
   }
 
   override def copy(extra: ParamMap): PCAModel = {
@@ -201,7 +197,7 @@ object PCAModel extends MLReadable[PCAModel] {
       val versionRegex = "([0-9]+)\\.([0-9]+).*".r
       val hasExplainedVariance = metadata.sparkVersion match {
         case versionRegex(major, minor) =>
-          (major.toInt >= 2 || (major.toInt == 1 && minor.toInt > 6))
+          major.toInt >= 2 || (major.toInt == 1 && minor.toInt > 6)
         case _ => false
       }
 
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
index 9d084b520c..7cec369c23 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
@@ -62,6 +62,15 @@ private[feature] trait StandardScalerParams extends Params with HasInputCol with
   /** @group getParam */
   def getWithStd: Boolean = $(withStd)
 
+  /** Validates and transforms the input schema. */
+  protected def validateAndTransformSchema(schema: StructType): StructType = {
+    SchemaUtils.checkColumnType(schema, $(inputCol), new VectorUDT)
+    require(!schema.fieldNames.contains($(outputCol)),
+      s"Output column ${$(outputCol)} already exists.")
+    val outputFields = schema.fields :+ StructField($(outputCol), new VectorUDT, false)
+    StructType(outputFields)
+  }
+
   setDefault(withMean -> false, withStd -> true)
 }
 
@@ -105,13 +114,7 @@ class StandardScaler(override val uid: String) extends Estimator[StandardScalerM
   }
 
   override def transformSchema(schema: StructType): StructType = {
-    val inputType = schema($(inputCol)).dataType
-    require(inputType.isInstanceOf[VectorUDT],
-      s"Input column ${$(inputCol)} must be a vector column")
-    require(!schema.fieldNames.contains($(outputCol)),
-      s"Output column ${$(outputCol)} already exists.")
-    val outputFields = schema.fields :+ StructField($(outputCol), new VectorUDT, false)
-    StructType(outputFields)
+    validateAndTransformSchema(schema)
   }
 
   override def copy(extra: ParamMap): StandardScaler = defaultCopy(extra)
@@ -159,13 +162,7 @@ class StandardScalerModel private[ml] (
   }
 
   override def transformSchema(schema: StructType): StructType = {
-    val inputType = schema($(inputCol)).dataType
-    require(inputType.isInstanceOf[VectorUDT],
-      s"Input column ${$(inputCol)} must be a vector column")
-    require(!schema.fieldNames.contains($(outputCol)),
-      s"Output column ${$(outputCol)} already exists.")
-    val outputFields = schema.fields :+ StructField($(outputCol), new VectorUDT, false)
-    StructType(outputFields)
+    validateAndTransformSchema(schema)
   }
 
   override def copy(extra: ParamMap): StandardScalerModel = {
-- 
GitLab
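
For context on the helper this patch standardizes on: `SchemaUtils.checkColumnType` (from `org.apache.spark.ml.util`) fails with a message that names both the expected and the actual column type, which is what removes the confusion when an input column still carries the old `mllib.linalg` vector type instead of the new `ml.linalg` one. The sketch below is a minimal, self-contained approximation of the consolidated `validateAndTransformSchema` pattern; the simplified `StructType`/`StructField` stand-ins (with types modeled as plain strings) are assumptions made so the snippet runs without a Spark dependency, and the code is illustrative only, not Spark's actual implementation.

```scala
// Simplified stand-ins for org.apache.spark.sql.types.{StructType, StructField};
// column types are modeled as plain strings for illustration.
case class StructField(name: String, dataType: String, nullable: Boolean = true)

case class StructType(fields: Seq[StructField]) {
  def fieldNames: Seq[String] = fields.map(_.name)
  def apply(name: String): StructField =
    fields.find(_.name == name).getOrElse(
      throw new IllegalArgumentException(s"""Field "$name" does not exist."""))
}

object SchemaCheckSketch {
  // Mirrors the intent of SchemaUtils.checkColumnType: report the expected
  // and the actual type, so a user who passed an old mllib.linalg vector
  // sees immediately why the new ml.linalg VectorUDT is required.
  def checkColumnType(schema: StructType, colName: String, expected: String): Unit = {
    val actual = schema(colName).dataType
    require(actual == expected,
      s"Column $colName must be of type $expected but was actually $actual.")
  }

  // The shared validation the patch introduces: check the input column type,
  // refuse to overwrite an existing output column, then append the new field.
  def validateAndTransformSchema(
      schema: StructType, inputCol: String, outputCol: String): StructType = {
    checkColumnType(schema, inputCol, expected = "ml.linalg.VectorUDT")
    require(!schema.fieldNames.contains(outputCol),
      s"Output column $outputCol already exists.")
    StructType(schema.fields :+ StructField(outputCol, "ml.linalg.VectorUDT", nullable = false))
  }

  def main(args: Array[String]): Unit = {
    // An input column that still holds the old MLlib vector type.
    val schema = StructType(Seq(StructField("features", "mllib.linalg.VectorUDT")))
    // Throws IllegalArgumentException("requirement failed: Column features
    // must be of type ml.linalg.VectorUDT but was actually
    // mllib.linalg.VectorUDT."), which is far clearer than the previous
    // "Input column features must be a vector column".
    validateAndTransformSchema(schema, inputCol = "features", outputCol = "scaled")
  }
}
```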