Skip to content
Snippets Groups Projects
Commit b89d0556 authored by Wojciech Szymanski, committed by Yanbo Liang
Browse files

[SPARK-18210][ML] Pipeline.copy does not create an instance with the same UID

## What changes were proposed in this pull request?

Motivation:
`org.apache.spark.ml.Pipeline.copy(extra: ParamMap)` does not create an instance with the same UID, so it does not conform to the method specification of its base class, `org.apache.spark.ml.param.Params.copy(extra: ParamMap)`.

Solution:
- fix for Pipeline UID
- introduced new tests for `org.apache.spark.ml.Pipeline.copy`
- minor improvements in test for `org.apache.spark.ml.PipelineModel.copy`

## How was this patch tested?

Introduced new unit test: `org.apache.spark.ml.PipelineSuite."Pipeline.copy"`
Improved existing unit test: `org.apache.spark.ml.PipelineSuite."PipelineModel.copy"`

Author: Wojciech Szymanski <wk.szymanski@gmail.com>

Closes #15759 from wojtek-szymanski/SPARK-18210.
parent 340f09d1
No related branches found
No related tags found
No related merge requests found
......@@ -169,7 +169,7 @@ class Pipeline @Since("1.4.0") (
override def copy(extra: ParamMap): Pipeline = {
  // Copies this pipeline: every stage is copied with `extra`, and the copies are set on a
  // new Pipeline that keeps this instance's UID.
  val map = extractParamMap(extra)
  val newStages = map(stages).map(_.copy(extra))
  // Pass `uid` explicitly: the no-arg `new Pipeline()` does not yield an instance with the
  // same UID, which is what a copy is required to have.
  new Pipeline(uid).setStages(newStages)
}
@Since("1.2.0")
......
......@@ -101,13 +101,31 @@ class PipelineSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
}
}
test("Pipeline.copy") {
  // Single-stage pipeline with an explicit UID so the copy's UID can be compared against it.
  val tf = new HashingTF().setNumFeatures(100)
  val original = new Pipeline("pipeline").setStages(Array[Transformer](tf))
  // Override the stage's numFeatures through the extra ParamMap passed to copy.
  val overrides = ParamMap(tf.numFeatures -> 10)
  val duplicate = original.copy(overrides)
  assert(duplicate.uid === original.uid, "copy should create an instance with the same UID")
  val copiedStage = duplicate.getStages(0).asInstanceOf[HashingTF]
  assert(copiedStage.getNumFeatures === 10, "copy should handle extra stage params")
}
test("PipelineModel.copy") {
  val hashingTF = new HashingTF()
    .setNumFeatures(100)
  // Give the model a parent so the test can check that copy carries it over.
  val model = new PipelineModel("pipelineModel", Array[Transformer](hashingTF))
    .setParent(new Pipeline())
  // The extra ParamMap overrides the stage's numFeatures on the copy.
  val copied = model.copy(ParamMap(hashingTF.numFeatures -> 10))
  assert(copied.uid === model.uid,
    "copy should create an instance with the same UID")
  assert(copied.stages(0).asInstanceOf[HashingTF].getNumFeatures === 10,
    "copy should handle extra stage params")
  assert(copied.parent === model.parent,
    "copy should create an instance with the same parent")
}
test("pipeline model constructors") {
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment