From a5fed34355b403188ad50b567ab62ee54597b493 Mon Sep 17 00:00:00 2001
From: "Joseph K. Bradley" <joseph@databricks.com>
Date: Thu, 19 Feb 2015 12:46:27 -0800
Subject: [PATCH] [SPARK-5902] [ml] Made PipelineStage.transformSchema public
 instead of private to ml

For users to implement their own PipelineStages, we need to make PipelineStage.transformSchema be public instead of private to ml.  This would be nice to include in Spark 1.3

CC: mengxr

Author: Joseph K. Bradley <joseph@databricks.com>

Closes #4682 from jkbradley/SPARK-5902 and squashes the following commits:

6f02357 [Joseph K. Bradley] Made transformSchema public
0e6d0a0 [Joseph K. Bradley] made implementations of transformSchema protected as well
fdaf26a [Joseph K. Bradley] Made PipelineStage.transformSchema protected instead of private[ml]
---
 .../scala/org/apache/spark/ml/Pipeline.scala     | 16 ++++++++++++----
 .../apache/spark/ml/feature/StandardScaler.scala |  4 ++--
 .../spark/ml/impl/estimator/Predictor.scala      |  4 ++--
 .../org/apache/spark/ml/recommendation/ALS.scala |  4 ++--
 .../apache/spark/ml/tuning/CrossValidator.scala  |  4 ++--
 5 files changed, 20 insertions(+), 12 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
index 5607ed21af..5bbcd2e080 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
@@ -20,7 +20,7 @@ package org.apache.spark.ml
 import scala.collection.mutable.ListBuffer
 
 import org.apache.spark.Logging
-import org.apache.spark.annotation.AlphaComponent
+import org.apache.spark.annotation.{AlphaComponent, DeveloperApi}
 import org.apache.spark.ml.param.{Param, ParamMap}
 import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.types.StructType
@@ -33,9 +33,17 @@ import org.apache.spark.sql.types.StructType
 abstract class PipelineStage extends Serializable with Logging {
 
   /**
+   * :: DeveloperAPI ::
+   *
    * Derives the output schema from the input schema and parameters.
+   * The schema describes the columns and types of the data.
+   *
+   * @param schema  Input schema to this stage
+   * @param paramMap  Parameters passed to this stage
+   * @return  Output schema from this stage
    */
-  private[ml] def transformSchema(schema: StructType, paramMap: ParamMap): StructType
+  @DeveloperApi
+  def transformSchema(schema: StructType, paramMap: ParamMap): StructType
 
   /**
    * Derives the output schema from the input schema and parameters, optionally with logging.
@@ -126,7 +134,7 @@ class Pipeline extends Estimator[PipelineModel] {
     new PipelineModel(this, map, transformers.toArray)
   }
 
-  private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
+  override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
     val map = this.paramMap ++ paramMap
     val theStages = map(stages)
     require(theStages.toSet.size == theStages.size,
@@ -171,7 +179,7 @@ class PipelineModel private[ml] (
     stages.foldLeft(dataset)((cur, transformer) => transformer.transform(cur, map))
   }
 
-  private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
+  override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
     // Precedence of ParamMaps: paramMap > this.paramMap > fittingParamMap
     val map = (fittingParamMap ++ this.paramMap) ++ paramMap
     stages.foldLeft(schema)((cur, transformer) => transformer.transformSchema(cur, map))
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
index ddbd648d64..1142aa4f8e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
@@ -55,7 +55,7 @@ class StandardScaler extends Estimator[StandardScalerModel] with StandardScalerP
     model
   }
 
-  private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
+  override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
     val map = this.paramMap ++ paramMap
     val inputType = schema(map(inputCol)).dataType
     require(inputType.isInstanceOf[VectorUDT],
@@ -91,7 +91,7 @@ class StandardScalerModel private[ml] (
     dataset.withColumn(map(outputCol), scale(col(map(inputCol))))
   }
 
-  private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
+  override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
     val map = this.paramMap ++ paramMap
     val inputType = schema(map(inputCol)).dataType
     require(inputType.isInstanceOf[VectorUDT],
diff --git a/mllib/src/main/scala/org/apache/spark/ml/impl/estimator/Predictor.scala b/mllib/src/main/scala/org/apache/spark/ml/impl/estimator/Predictor.scala
index 7daeff980f..dfb89cc8d4 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/impl/estimator/Predictor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/impl/estimator/Predictor.scala
@@ -132,7 +132,7 @@ private[spark] abstract class Predictor[
   @DeveloperApi
   protected def featuresDataType: DataType = new VectorUDT
 
-  private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
+  override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
     validateAndTransformSchema(schema, paramMap, fitting = true, featuresDataType)
   }
 
@@ -184,7 +184,7 @@ private[spark] abstract class PredictionModel[FeaturesType, M <: PredictionModel
   @DeveloperApi
   protected def featuresDataType: DataType = new VectorUDT
 
-  private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
+  override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
     validateAndTransformSchema(schema, paramMap, fitting = false, featuresDataType)
   }
 
diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
index 8d70e4347c..c2ec716f08 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
@@ -188,7 +188,7 @@ class ALSModel private[ml] (
       .select(dataset("*"), predict(users("features"), items("features")).as(map(predictionCol)))
   }
 
-  override private[ml] def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
+  override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
     validateAndTransformSchema(schema, paramMap)
   }
 }
@@ -292,7 +292,7 @@ class ALS extends Estimator[ALSModel] with ALSParams {
     model
   }
 
-  override private[ml] def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
+  override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
     validateAndTransformSchema(schema, paramMap)
   }
 }
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
index b07a68269c..2eb1dac56f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
@@ -129,7 +129,7 @@ class CrossValidator extends Estimator[CrossValidatorModel] with CrossValidatorP
     cvModel
   }
 
-  private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
+  override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
     val map = this.paramMap ++ paramMap
     map(estimator).transformSchema(schema, paramMap)
   }
@@ -150,7 +150,7 @@ class CrossValidatorModel private[ml] (
     bestModel.transform(dataset, paramMap)
   }
 
-  private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
+  override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
     bestModel.transformSchema(schema, paramMap)
   }
 }
-- 
GitLab