Commit f5ebb18c authored by Joseph K. Bradley

[SPARK-14671][ML] Pipeline setStages should handle subclasses of PipelineStage

## What changes were proposed in this pull request?

Pipeline.setStages failed for some code examples which worked in 1.5 but failed in 1.6. This tends to occur when using a mix of transformers from ml.feature: because Java Arrays are non-covariant and MLWritable was added to some transformers, arrays like the stages0/stages1 arrays in the unit test below are no longer of type Array[PipelineStage]. This PR modifies the following to accept subclasses of PipelineStage (an illustrative sketch follows the list):
* Pipeline.setStages()
* Params.w()
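
As an illustration (not taken from the patch), here is a minimal sketch of the kind of user code this affects, assuming a mix of ml.feature stages such as HashingTF and MinMaxScaler:

```scala
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.feature.{HashingTF, MinMaxScaler}

// Both stages extend PipelineStage, but because they also mix in other traits
// (e.g. MLWritable), Scala infers a more specific element type for this array
// than PipelineStage. Arrays are invariant, so the result is not an
// Array[PipelineStage].
val stages = Array(new HashingTF(), new MinMaxScaler())

// With the old signature, setStages(value: Array[PipelineStage]), this call
// did not compile; with setStages(value: Array[_ <: PipelineStage]) it does.
val pipeline = new Pipeline().setStages(stages)
```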

## How was this patch tested?

Unit test which fails to compile before this fix.

Author: Joseph K. Bradley <joseph@databricks.com>

Closes #12430 from jkbradley/pipeline-setstages.
parent 6466d6c8
@@ -103,7 +103,10 @@ class Pipeline @Since("1.4.0") (

   /** @group setParam */
   @Since("1.2.0")
-  def setStages(value: Array[PipelineStage]): this.type = { set(stages, value); this }
+  def setStages(value: Array[_ <: PipelineStage]): this.type = {
+    set(stages, value.asInstanceOf[Array[PipelineStage]])
+    this
+  }

   // Below, we clone stages so that modifications to the list of stages will not change
   // the Param value in the Pipeline.
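
A side note on the design choice here (not part of the patch): Scala's Array is invariant in its element type, which is why the new signature uses the upper bound `Array[_ <: PipelineStage]` and then casts back to `Array[PipelineStage]` before storing the value in the `stages` param. A minimal standalone sketch of the invariance issue:

```scala
import org.apache.spark.ml.{PipelineStage, Transformer}

// Arrays are invariant: even though every Transformer is a PipelineStage,
// an Array[Transformer] is not an Array[PipelineStage].
val transformers: Array[Transformer] = Array.empty[Transformer]

// val bad: Array[PipelineStage] = transformers   // does not compile

// An existential upper bound accepts any array whose element type is a
// subtype of PipelineStage, which is what the new setStages signature relies on.
val ok: Array[_ <: PipelineStage] = transformers
```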
@@ -27,7 +27,7 @@ import org.scalatest.mock.MockitoSugar.mock
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.Pipeline.SharedReadWrite
 import org.apache.spark.ml.feature.{HashingTF, MinMaxScaler}
-import org.apache.spark.ml.param.{IntParam, ParamMap}
+import org.apache.spark.ml.param.{IntParam, ParamMap, ParamPair}
 import org.apache.spark.ml.util._
 import org.apache.spark.mllib.linalg.Vectors
 import org.apache.spark.mllib.util.MLlibTestSparkContext
@@ -201,6 +201,13 @@ class PipelineSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest
       pipeline.fit(df)
     }
   }
+
+  test("Pipeline.setStages should handle Java Arrays being non-covariant") {
+    val stages0 = Array(new UnWritableStage("b"))
+    val stages1 = Array(new WritableStage("a"))
+    val steps = stages0 ++ stages1
+    val p = new Pipeline().setStages(steps)
+  }
 }