From 40e6d40fe79ce45d511e049133d2f30a2963740b Mon Sep 17 00:00:00 2001
From: Yanbo Liang <ybliang8@gmail.com>
Date: Mon, 22 Feb 2016 12:59:50 +0200
Subject: [PATCH] [SPARK-13334][ML] ML KMeansModel / BisectingKMeansModel /
 QuantileDiscretizer should set parent

ML ```KMeansModel / BisectingKMeansModel / QuantileDiscretizer``` should set parent.

cc mengxr

Author: Yanbo Liang <ybliang8@gmail.com>

Closes #11214 from yanboliang/spark-13334.
---
 .../org/apache/spark/ml/clustering/BisectingKMeans.scala      | 2 +-
 .../main/scala/org/apache/spark/ml/clustering/KMeans.scala    | 2 +-
 .../org/apache/spark/ml/feature/QuantileDiscretizer.scala     | 2 +-
 .../org/apache/spark/ml/clustering/BisectingKMeansSuite.scala | 1 +
 .../scala/org/apache/spark/ml/clustering/KMeansSuite.scala    | 1 +
 .../apache/spark/ml/feature/QuantileDiscretizerSuite.scala    | 4 +++-
 6 files changed, 8 insertions(+), 4 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala
index 0b47cbbac8..45d293bc69 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala
@@ -185,7 +185,7 @@ class BisectingKMeans @Since("2.0.0") (
       .setSeed($(seed))
     val parentModel = bkm.run(rdd)
     val model = new BisectingKMeansModel(uid, parentModel)
-    copyValues(model)
+    copyValues(model.setParent(this))
   }
 
   @Since("2.0.0")
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
index dc6d5d9280..b2292e20e2 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
@@ -250,7 +250,7 @@ class KMeans @Since("1.5.0") (
       .setEpsilon($(tol))
     val parentModel = algo.run(rdd)
     val model = new KMeansModel(uid, parentModel)
-    copyValues(model)
+    copyValues(model.setParent(this))
   }
 
   @Since("1.5.0")
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
index 2a294d3881..1f4cca1233 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
@@ -95,7 +95,7 @@ final class QuantileDiscretizer(override val uid: String)
     val candidates = QuantileDiscretizer.findSplitCandidates(samples, $(numBuckets) - 1)
     val splits = QuantileDiscretizer.getSplits(candidates)
     val bucketizer = new Bucketizer(uid).setSplits(splits)
-    copyValues(bucketizer)
+    copyValues(bucketizer.setParent(this))
   }
 
   override def copy(extra: ParamMap): QuantileDiscretizer = defaultCopy(extra)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala
index b26571eb9f..fc4a4add5d 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala
@@ -81,5 +81,6 @@ class BisectingKMeansSuite extends SparkFunSuite with MLlibTestSparkContext {
     assert(clusters.size === k)
     assert(clusters === Set(0, 1, 2, 3, 4))
     assert(model.computeCost(dataset) < 0.1)
+    assert(model.hasParent)
   }
 }
diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
index 2724e51f31..e5357ba8e2 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
@@ -97,6 +97,7 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR
     assert(clusters.size === k)
     assert(clusters === Set(0, 1, 2, 3, 4))
     assert(model.computeCost(dataset) < 0.1)
+    assert(model.hasParent)
   }
 
   test("read/write") {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala
index 4fde42972f..6a2c601bbe 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala
@@ -94,7 +94,9 @@ private object QuantileDiscretizerSuite extends SparkFunSuite {
     val df = sc.parallelize(data.map(Tuple1.apply)).toDF("input")
     val discretizer = new QuantileDiscretizer().setInputCol("input").setOutputCol("result")
       .setNumBuckets(numBucket).setSeed(1)
-    val result = discretizer.fit(df).transform(df)
+    val model = discretizer.fit(df)
+    assert(model.hasParent)
+    val result = model.transform(df)
 
     val transformedFeatures = result.select("result").collect()
       .map { case Row(transformedFeature: Double) => transformedFeature }
-- 
GitLab