Skip to content
Snippets Groups Projects
Commit f7b55dbf authored by Yanbo Liang's avatar Yanbo Liang Committed by Xiangrui Meng
Browse files

[SPARK-10470] [ML] ml.IsotonicRegressionModel.copy should set parent

Copied model must have the same parent, but ml.IsotonicRegressionModel.copy did not set parent.
Here fix it and add test case.

Author: Yanbo Liang <ybliang8@gmail.com>

Closes #8637 from yanboliang/spark-10470.
parent 5fd57955
No related branches found
No related tags found
No related merge requests found
...@@ -203,7 +203,7 @@ class IsotonicRegressionModel private[ml] ( ...@@ -203,7 +203,7 @@ class IsotonicRegressionModel private[ml] (
def predictions: Vector = Vectors.dense(oldModel.predictions) def predictions: Vector = Vectors.dense(oldModel.predictions)
override def copy(extra: ParamMap): IsotonicRegressionModel = { override def copy(extra: ParamMap): IsotonicRegressionModel = {
copyValues(new IsotonicRegressionModel(uid, oldModel), extra) copyValues(new IsotonicRegressionModel(uid, oldModel), extra).setParent(parent)
} }
override def transform(dataset: DataFrame): DataFrame = { override def transform(dataset: DataFrame): DataFrame = {
......
...@@ -19,6 +19,7 @@ package org.apache.spark.ml.regression ...@@ -19,6 +19,7 @@ package org.apache.spark.ml.regression
import org.apache.spark.SparkFunSuite import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.ml.util.MLTestingUtils
import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.{DataFrame, Row}
...@@ -89,6 +90,10 @@ class IsotonicRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { ...@@ -89,6 +90,10 @@ class IsotonicRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
assert(ir.getFeatureIndex === 0) assert(ir.getFeatureIndex === 0)
val model = ir.fit(dataset) val model = ir.fit(dataset)
// copied model must have the same parent.
MLTestingUtils.checkCopy(model)
model.transform(dataset) model.transform(dataset)
.select("label", "features", "prediction", "weight") .select("label", "features", "prediction", "weight")
.collect() .collect()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment