From 5855e0057defeab8006ca4f7b0196003bbc9e899 Mon Sep 17 00:00:00 2001
From: Yuhao Yang <yuhao.yang@intel.com>
Date: Thu, 2 Jun 2016 16:37:01 -0700
Subject: [PATCH] [SPARK-15668][ML] ml.feature: update schema checks to avoid
 confusion when users use mllib.Vector as the input type

## What changes were proposed in this pull request?

ml.feature: update the schema checks in MaxAbsScaler, MinMaxScaler, PCA, and StandardScaler to avoid confusion when users use mllib.Vector as the input type. The input-column check now goes through SchemaUtils.checkColumnType, which reports both the required and the actual column type, and PCA and StandardScaler factor the duplicated transformSchema logic into a shared validateAndTransformSchema (see the sketch below).
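
For context, a minimal sketch of the confusion this addresses (the column name, session setup, and quoted messages are illustrative, not part of this patch): when a DataFrame column holds the old mllib.linalg.Vector UDT, the previous `require` only said the column "must be a vector column", which is confusing because the column *is* a vector, just the mllib one. `SchemaUtils.checkColumnType` instead names both the required and the actual type.

```scala
// Sketch only: demonstrates the error-message difference. The "features"
// column name and local SparkSession are illustrative assumptions.
import org.apache.spark.ml.feature.MinMaxScaler
import org.apache.spark.mllib.linalg.Vectors // the *old* mllib vector type
import org.apache.spark.sql.SparkSession

object SchemaCheckDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[2]")
      .appName("SchemaCheckDemo")
      .getOrCreate()
    import spark.implicits._

    // DataFrame whose "features" column carries the old mllib.linalg.Vector UDT
    // rather than the new ml.linalg.Vector expected by ml.feature transformers.
    val df = Seq(Tuple1(Vectors.dense(1.0, 2.0))).toDF("features")

    val scaler = new MinMaxScaler()
      .setInputCol("features")
      .setOutputCol("scaled")

    try {
      scaler.fit(df) // transformSchema rejects the mllib vector column
    } catch {
      case e: IllegalArgumentException =>
        // Before this patch: "Input column features must be a vector column".
        // After this patch, SchemaUtils.checkColumnType reports both types, e.g.:
        // "requirement failed: Column features must be of type
        //  org.apache.spark.ml.linalg.VectorUDT@... but was actually
        //  org.apache.spark.mllib.linalg.VectorUDT@..."
        println(e.getMessage)
    }

    spark.stop()
  }
}
```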

## How was this patch tested?
Existing unit tests.

Author: Yuhao Yang <yuhao.yang@intel.com>

Closes #13411 from hhbyyh/schemaCheck.
---
 .../spark/ml/feature/MaxAbsScaler.scala       |  4 +--
 .../spark/ml/feature/MinMaxScaler.scala       |  4 +--
 .../org/apache/spark/ml/feature/PCA.scala     | 28 ++++++++-----------
 .../spark/ml/feature/StandardScaler.scala     | 25 ++++++++---------
 4 files changed, 25 insertions(+), 36 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala
index 1b5159902e..7298a18ff8 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala
@@ -39,9 +39,7 @@ private[feature] trait MaxAbsScalerParams extends Params with HasInputCol with H
 
    /** Validates and transforms the input schema. */
   protected def validateAndTransformSchema(schema: StructType): StructType = {
-    val inputType = schema($(inputCol)).dataType
-    require(inputType.isInstanceOf[VectorUDT],
-      s"Input column ${$(inputCol)} must be a vector column")
+    SchemaUtils.checkColumnType(schema, $(inputCol), new VectorUDT)
     require(!schema.fieldNames.contains($(outputCol)),
       s"Output column ${$(outputCol)} already exists.")
     val outputFields = schema.fields :+ StructField($(outputCol), new VectorUDT, false)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala
index d15f1b8563..a27bed5333 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala
@@ -63,9 +63,7 @@ private[feature] trait MinMaxScalerParams extends Params with HasInputCol with H
   /** Validates and transforms the input schema. */
   protected def validateAndTransformSchema(schema: StructType): StructType = {
     require($(min) < $(max), s"The specified min(${$(min)}) is larger or equal to max(${$(max)})")
-    val inputType = schema($(inputCol)).dataType
-    require(inputType.isInstanceOf[VectorUDT],
-      s"Input column ${$(inputCol)} must be a vector column")
+    SchemaUtils.checkColumnType(schema, $(inputCol), new VectorUDT)
     require(!schema.fieldNames.contains($(outputCol)),
       s"Output column ${$(outputCol)} already exists.")
     val outputFields = schema.fields :+ StructField($(outputCol), new VectorUDT, false)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala
index dbbaa5aa46..2f667af9d1 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala
@@ -49,8 +49,16 @@ private[feature] trait PCAParams extends Params with HasInputCol with HasOutputC
   /** @group getParam */
   def getK: Int = $(k)
 
-}
+  /** Validates and transforms the input schema. */
+  protected def validateAndTransformSchema(schema: StructType): StructType = {
+    SchemaUtils.checkColumnType(schema, $(inputCol), new VectorUDT)
+    require(!schema.fieldNames.contains($(outputCol)),
+      s"Output column ${$(outputCol)} already exists.")
+    val outputFields = schema.fields :+ StructField($(outputCol), new VectorUDT, false)
+    StructType(outputFields)
+  }
 
+}
 /**
  * :: Experimental ::
  * PCA trains a model to project vectors to a lower dimensional space of the top [[PCA!.k]]
@@ -86,13 +94,7 @@ class PCA (override val uid: String) extends Estimator[PCAModel] with PCAParams
   }
 
   override def transformSchema(schema: StructType): StructType = {
-    val inputType = schema($(inputCol)).dataType
-    require(inputType.isInstanceOf[VectorUDT],
-      s"Input column ${$(inputCol)} must be a vector column")
-    require(!schema.fieldNames.contains($(outputCol)),
-      s"Output column ${$(outputCol)} already exists.")
-    val outputFields = schema.fields :+ StructField($(outputCol), new VectorUDT, false)
-    StructType(outputFields)
+    validateAndTransformSchema(schema)
   }
 
   override def copy(extra: ParamMap): PCA = defaultCopy(extra)
@@ -148,13 +150,7 @@ class PCAModel private[ml] (
   }
 
   override def transformSchema(schema: StructType): StructType = {
-    val inputType = schema($(inputCol)).dataType
-    require(inputType.isInstanceOf[VectorUDT],
-      s"Input column ${$(inputCol)} must be a vector column")
-    require(!schema.fieldNames.contains($(outputCol)),
-      s"Output column ${$(outputCol)} already exists.")
-    val outputFields = schema.fields :+ StructField($(outputCol), new VectorUDT, false)
-    StructType(outputFields)
+    validateAndTransformSchema(schema)
   }
 
   override def copy(extra: ParamMap): PCAModel = {
@@ -201,7 +197,7 @@ object PCAModel extends MLReadable[PCAModel] {
       val versionRegex = "([0-9]+)\\.([0-9]+).*".r
       val hasExplainedVariance = metadata.sparkVersion match {
         case versionRegex(major, minor) =>
-          (major.toInt >= 2 || (major.toInt == 1 && minor.toInt > 6))
+          major.toInt >= 2 || (major.toInt == 1 && minor.toInt > 6)
         case _ => false
       }
 
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
index 9d084b520c..7cec369c23 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
@@ -62,6 +62,15 @@ private[feature] trait StandardScalerParams extends Params with HasInputCol with
   /** @group getParam */
   def getWithStd: Boolean = $(withStd)
 
+  /** Validates and transforms the input schema. */
+  protected def validateAndTransformSchema(schema: StructType): StructType = {
+    SchemaUtils.checkColumnType(schema, $(inputCol), new VectorUDT)
+    require(!schema.fieldNames.contains($(outputCol)),
+      s"Output column ${$(outputCol)} already exists.")
+    val outputFields = schema.fields :+ StructField($(outputCol), new VectorUDT, false)
+    StructType(outputFields)
+  }
+
   setDefault(withMean -> false, withStd -> true)
 }
 
@@ -105,13 +114,7 @@ class StandardScaler(override val uid: String) extends Estimator[StandardScalerM
   }
 
   override def transformSchema(schema: StructType): StructType = {
-    val inputType = schema($(inputCol)).dataType
-    require(inputType.isInstanceOf[VectorUDT],
-      s"Input column ${$(inputCol)} must be a vector column")
-    require(!schema.fieldNames.contains($(outputCol)),
-      s"Output column ${$(outputCol)} already exists.")
-    val outputFields = schema.fields :+ StructField($(outputCol), new VectorUDT, false)
-    StructType(outputFields)
+    validateAndTransformSchema(schema)
   }
 
   override def copy(extra: ParamMap): StandardScaler = defaultCopy(extra)
@@ -159,13 +162,7 @@ class StandardScalerModel private[ml] (
   }
 
   override def transformSchema(schema: StructType): StructType = {
-    val inputType = schema($(inputCol)).dataType
-    require(inputType.isInstanceOf[VectorUDT],
-      s"Input column ${$(inputCol)} must be a vector column")
-    require(!schema.fieldNames.contains($(outputCol)),
-      s"Output column ${$(outputCol)} already exists.")
-    val outputFields = schema.fields :+ StructField($(outputCol), new VectorUDT, false)
-    StructType(outputFields)
+    validateAndTransformSchema(schema)
   }
 
   override def copy(extra: ParamMap): StandardScalerModel = {
-- 
GitLab