diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
index 5607ed21afe1808f1c0c643754a32589d8c21369..5bbcd2e080e077077b9d0d88ca092fddc1b79759 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
@@ -20,7 +20,7 @@ package org.apache.spark.ml
 import scala.collection.mutable.ListBuffer
 
 import org.apache.spark.Logging
-import org.apache.spark.annotation.AlphaComponent
+import org.apache.spark.annotation.{AlphaComponent, DeveloperApi}
 import org.apache.spark.ml.param.{Param, ParamMap}
 import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.types.StructType
@@ -33,9 +33,17 @@ import org.apache.spark.sql.types.StructType
 abstract class PipelineStage extends Serializable with Logging {
 
   /**
+   * :: DeveloperApi ::
+   *
    * Derives the output schema from the input schema and parameters.
+   * The schema describes the columns and types of the data.
+   *
+   * @param schema  Input schema to this stage
+   * @param paramMap  Parameters passed to this stage
+   * @return  Output schema from this stage
    */
-  private[ml] def transformSchema(schema: StructType, paramMap: ParamMap): StructType
+  @DeveloperApi
+  def transformSchema(schema: StructType, paramMap: ParamMap): StructType
 
   /**
    * Derives the output schema from the input schema and parameters, optionally with logging.
@@ -126,7 +134,7 @@ class Pipeline extends Estimator[PipelineModel] {
     new PipelineModel(this, map, transformers.toArray)
   }
 
-  private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
+  override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
     val map = this.paramMap ++ paramMap
     val theStages = map(stages)
     require(theStages.toSet.size == theStages.size,
@@ -171,7 +179,7 @@ class PipelineModel private[ml] (
     stages.foldLeft(dataset)((cur, transformer) => transformer.transform(cur, map))
   }
 
-  private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
+  override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
     // Precedence of ParamMaps: paramMap > this.paramMap > fittingParamMap
     val map = (fittingParamMap ++ this.paramMap) ++ paramMap
     stages.foldLeft(schema)((cur, transformer) => transformer.transformSchema(cur, map))
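With transformSchema now public behind @DeveloperApi instead of private[ml], pipeline stages defined outside the org.apache.spark.ml package can override it, and Pipeline.transformSchema can validate a whole chain of third-party stages before any data is processed. A minimal sketch of such a stage, assuming the Transformer API as of this patch; the class and column names (UppercaseColumn, "text", "text_upper") are illustrative, not part of this change:

import org.apache.spark.ml.Transformer
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.upper
import org.apache.spark.sql.types.{StringType, StructField, StructType}

// Hypothetical third-party stage: upper-cases a "text" column into "text_upper".
class UppercaseColumn extends Transformer {

  override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
    dataset.withColumn("text_upper", upper(dataset("text")))
  }

  // Overridable from user code now that the member is public: declaring the
  // output column here lets Pipeline.transformSchema check the whole chain
  // up front, without running a job.
  override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
    require(schema.fieldNames.contains("text"),
      s"Expected a 'text' column but found: ${schema.fieldNames.mkString(", ")}")
    StructType(schema.fields :+ StructField("text_upper", StringType, nullable = true))
  }
}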
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
index ddbd648d64f23e774a69c4f3118dce18fd82b75d..1142aa4f8e73de74342270867988d86893fd76df 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
@@ -55,7 +55,7 @@ class StandardScaler extends Estimator[StandardScalerModel] with StandardScalerP
     model
   }
 
-  private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
+  override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
     val map = this.paramMap ++ paramMap
     val inputType = schema(map(inputCol)).dataType
     require(inputType.isInstanceOf[VectorUDT],
@@ -91,7 +91,7 @@ class StandardScalerModel private[ml] (
     dataset.withColumn(map(outputCol), scale(col(map(inputCol))))
   }
 
-  private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
+  override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
     val map = this.paramMap ++ paramMap
     val inputType = schema(map(inputCol)).dataType
     require(inputType.isInstanceOf[VectorUDT],
diff --git a/mllib/src/main/scala/org/apache/spark/ml/impl/estimator/Predictor.scala b/mllib/src/main/scala/org/apache/spark/ml/impl/estimator/Predictor.scala
index 7daeff980f0eac3db44ee000903fe3eecd6714b6..dfb89cc8d4af314956e54bf1b895aaec0f5b24a7 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/impl/estimator/Predictor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/impl/estimator/Predictor.scala
@@ -132,7 +132,7 @@ private[spark] abstract class Predictor[
   @DeveloperApi
   protected def featuresDataType: DataType = new VectorUDT
 
-  private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
+  override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
     validateAndTransformSchema(schema, paramMap, fitting = true, featuresDataType)
   }
 
@@ -184,7 +184,7 @@ private[spark] abstract class PredictionModel[FeaturesType, M <: PredictionModel
   @DeveloperApi
   protected def featuresDataType: DataType = new VectorUDT
 
-  private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
+  override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
     validateAndTransformSchema(schema, paramMap, fitting = false, featuresDataType)
   }
 
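The fitting flag is what differs between the Predictor and PredictionModel overrides: when fitting, the label column must be present and correctly typed, while at prediction time only the features column is required and the prediction column is appended. A hedged illustration of that contract under assumed column names ("features", "label", "prediction"); this is a sketch, not the body of validateAndTransformSchema:

import org.apache.spark.sql.types.{DataType, DoubleType, StructField, StructType}

// Illustrative stand-in for the shared helper; the exact checks are assumptions.
def validateAndTransformSchemaSketch(
    schema: StructType,
    fitting: Boolean,
    featuresDataType: DataType): StructType = {
  require(schema("features").dataType.getClass == featuresDataType.getClass,
    s"features column must be of type $featuresDataType")
  if (fitting) {
    // A label is only required while training.
    require(schema("label").dataType == DoubleType,
      "label column must be of type DoubleType")
  }
  // Both paths append the prediction column to the output schema.
  StructType(schema.fields :+ StructField("prediction", DoubleType, nullable = false))
}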
diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
index 8d70e4347c4c9ca4d4efa4f42d664b86ce53146c..c2ec716f08b7c6055fc55dd82f48954f58da280b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
@@ -188,7 +188,7 @@ class ALSModel private[ml] (
       .select(dataset("*"), predict(users("features"), items("features")).as(map(predictionCol)))
   }
 
-  override private[ml] def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
+  override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
     validateAndTransformSchema(schema, paramMap)
   }
 }
@@ -292,7 +292,7 @@ class ALS extends Estimator[ALSModel] with ALSParams {
     model
   }
 
-  override private[ml] def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
+  override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
     validateAndTransformSchema(schema, paramMap)
   }
 }
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
index b07a68269cc2b1699b463a9d5c97ea4dfd196504..2eb1dac56f1e9450ae739bc0153e5e9bbe67658c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
@@ -129,7 +129,7 @@ class CrossValidator extends Estimator[CrossValidatorModel] with CrossValidatorP
     cvModel
   }
 
-  private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
+  override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
     val map = this.paramMap ++ paramMap
     map(estimator).transformSchema(schema, paramMap)
   }
@@ -150,7 +150,7 @@ class CrossValidatorModel private[ml] (
     bestModel.transform(dataset, paramMap)
   }
 
-  private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
+  override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
     bestModel.transformSchema(schema, paramMap)
   }
 }
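At the call site, the net effect of this patch is that schema validation becomes available to users directly: a schema can be threaded through an entire pipeline without touching any data. A usage sketch against the ml.feature stages that existed at the time:

import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.feature.{HashingTF, Tokenizer}
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.sql.types.{StringType, StructField, StructType}

val tokenizer = new Tokenizer().setInputCol("text").setOutputCol("words")
val hashingTF = new HashingTF().setInputCol("words").setOutputCol("features")
val pipeline = new Pipeline().setStages(Array(tokenizer, hashingTF))

val inputSchema = StructType(Seq(StructField("text", StringType, nullable = true)))

// Fails fast on a missing or mistyped column; no Spark job runs.
val outputSchema = pipeline.transformSchema(inputSchema, ParamMap.empty)

Note the precedence encoded in PipelineModel.transformSchema above: a ParamMap passed at the call site overrides this.paramMap, which in turn overrides the fittingParamMap captured at fit time.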