From 1fa58868bc6635ff2119264665bd3d00b4b1253a Mon Sep 17 00:00:00 2001
From: Yanbo Liang <ybliang8@gmail.com>
Date: Wed, 8 Mar 2017 02:05:01 -0800
Subject: [PATCH] [ML][MINOR] Separate estimator and model params for
 read/write test.

## What changes were proposed in this pull request?
Since we allow ```Estimator``` and ```Model``` not always share same params (see ```ALSParams``` and ```ALSModelParams```), we should pass in test params for estimator and model separately in function ```testEstimatorAndModelReadWrite```.

## How was this patch tested?
Existing tests.

Author: Yanbo Liang <ybliang8@gmail.com>

Closes #17151 from yanboliang/test-rw.
---
 .../DecisionTreeClassifierSuite.scala         |  8 +++--
 .../classification/GBTClassifierSuite.scala   |  3 +-
 .../ml/classification/LinearSVCSuite.scala    |  2 +-
 .../LogisticRegressionSuite.scala             |  2 +-
 .../ml/classification/NaiveBayesSuite.scala   |  3 +-
 .../RandomForestClassifierSuite.scala         |  3 +-
 .../ml/clustering/BisectingKMeansSuite.scala  |  4 +--
 .../ml/clustering/GaussianMixtureSuite.scala  |  2 +-
 .../spark/ml/clustering/KMeansSuite.scala     |  3 +-
 .../apache/spark/ml/clustering/LDASuite.scala |  4 ++-
 .../BucketedRandomProjectionLSHSuite.scala    |  2 +-
 .../spark/ml/feature/ChiSqSelectorSuite.scala |  3 +-
 .../spark/ml/feature/MinHashLSHSuite.scala    |  2 +-
 .../apache/spark/ml/fpm/FPGrowthSuite.scala   |  4 +--
 .../spark/ml/recommendation/ALSSuite.scala    | 35 +++++++------------
 .../AFTSurvivalRegressionSuite.scala          |  3 +-
 .../DecisionTreeRegressorSuite.scala          |  5 +--
 .../ml/regression/GBTRegressorSuite.scala     |  3 +-
 .../GeneralizedLinearRegressionSuite.scala    |  1 +
 .../regression/IsotonicRegressionSuite.scala  |  2 +-
 .../ml/regression/LinearRegressionSuite.scala |  2 +-
 .../RandomForestRegressorSuite.scala          |  3 +-
 .../spark/ml/util/DefaultReadWriteTest.scala  | 14 ++++----
 23 files changed, 59 insertions(+), 54 deletions(-)

diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
index c711e7fa9d..10de50306a 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
@@ -372,16 +372,18 @@ class DecisionTreeClassifierSuite
     // Categorical splits with tree depth 2
     val categoricalData: DataFrame =
       TreeTests.setMetadata(rdd, Map(0 -> 2, 1 -> 3), numClasses = 2)
-    testEstimatorAndModelReadWrite(dt, categoricalData, allParamSettings, checkModelData)
+    testEstimatorAndModelReadWrite(dt, categoricalData, allParamSettings,
+      allParamSettings, checkModelData)
 
     // Continuous splits with tree depth 2
     val continuousData: DataFrame =
       TreeTests.setMetadata(rdd, Map.empty[Int, Int], numClasses = 2)
-    testEstimatorAndModelReadWrite(dt, continuousData, allParamSettings, checkModelData)
+    testEstimatorAndModelReadWrite(dt, continuousData, allParamSettings,
+      allParamSettings, checkModelData)
 
     // Continuous splits with tree depth 0
     testEstimatorAndModelReadWrite(dt, continuousData, allParamSettings ++ Map("maxDepth" -> 0),
-      checkModelData)
+      allParamSettings ++ Map("maxDepth" -> 0), checkModelData)
   }
 }
 
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
index 0598943c3d..0cddb37281 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
@@ -374,7 +374,8 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext
 
     val continuousData: DataFrame =
       TreeTests.setMetadata(rdd, Map.empty[Int, Int], numClasses = 2)
-    testEstimatorAndModelReadWrite(gbt, continuousData, allParamSettings, checkModelData)
+    testEstimatorAndModelReadWrite(gbt, continuousData, allParamSettings,
+      allParamSettings, checkModelData)
   }
 }
 
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala
index fe47176a4a..4c63a2a88c 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala
@@ -232,7 +232,7 @@ class LinearSVCSuite extends SparkFunSuite with MLlibTestSparkContext with Defau
     }
     val svm = new LinearSVC()
     testEstimatorAndModelReadWrite(svm, smallBinaryDataset, LinearSVCSuite.allParamSettings,
-      checkModelData)
+      LinearSVCSuite.allParamSettings, checkModelData)
   }
 }
 
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
index d89a958eed..affaa57374 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
@@ -2089,7 +2089,7 @@ class LogisticRegressionSuite
     }
     val lr = new LogisticRegression()
     testEstimatorAndModelReadWrite(lr, smallBinaryDataset, LogisticRegressionSuite.allParamSettings,
-      checkModelData)
+      LogisticRegressionSuite.allParamSettings, checkModelData)
   }
 
   test("should support all NumericType labels and weights, and not support other types") {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala
index 37d7991fe8..4d5d299d14 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala
@@ -280,7 +280,8 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
       assert(model.theta === model2.theta)
     }
     val nb = new NaiveBayes()
-    testEstimatorAndModelReadWrite(nb, dataset, NaiveBayesSuite.allParamSettings, checkModelData)
+    testEstimatorAndModelReadWrite(nb, dataset, NaiveBayesSuite.allParamSettings,
+      NaiveBayesSuite.allParamSettings, checkModelData)
   }
 
   test("should support all NumericType labels and weights, and not support other types") {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
index 44e1585ee5..c3003cec73 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
@@ -218,7 +218,8 @@ class RandomForestClassifierSuite
 
     val continuousData: DataFrame =
       TreeTests.setMetadata(rdd, Map.empty[Int, Int], numClasses = 2)
-    testEstimatorAndModelReadWrite(rf, continuousData, allParamSettings, checkModelData)
+    testEstimatorAndModelReadWrite(rf, continuousData, allParamSettings,
+      allParamSettings, checkModelData)
   }
 }
 
diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala
index 30513c1e27..200a892f6c 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala
@@ -138,8 +138,8 @@ class BisectingKMeansSuite
       assert(model.clusterCenters === model2.clusterCenters)
     }
     val bisectingKMeans = new BisectingKMeans()
-    testEstimatorAndModelReadWrite(
-      bisectingKMeans, dataset, BisectingKMeansSuite.allParamSettings, checkModelData)
+    testEstimatorAndModelReadWrite(bisectingKMeans, dataset, BisectingKMeansSuite.allParamSettings,
+      BisectingKMeansSuite.allParamSettings, checkModelData)
   }
 }
 
diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala
index c500c5b3e3..61da897b66 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala
@@ -163,7 +163,7 @@ class GaussianMixtureSuite extends SparkFunSuite with MLlibTestSparkContext
       assert(model.gaussians.map(_.cov) === model2.gaussians.map(_.cov))
     }
     val gm = new GaussianMixture()
-    testEstimatorAndModelReadWrite(gm, dataset,
+    testEstimatorAndModelReadWrite(gm, dataset, GaussianMixtureSuite.allParamSettings,
       GaussianMixtureSuite.allParamSettings, checkModelData)
   }
 
diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
index e10127f7d1..ca05b9c389 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
@@ -150,7 +150,8 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR
       assert(model.clusterCenters === model2.clusterCenters)
     }
     val kmeans = new KMeans()
-    testEstimatorAndModelReadWrite(kmeans, dataset, KMeansSuite.allParamSettings, checkModelData)
+    testEstimatorAndModelReadWrite(kmeans, dataset, KMeansSuite.allParamSettings,
+      KMeansSuite.allParamSettings, checkModelData)
   }
 }
 
diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala
index 9aa11fbdbe..75aa0be61a 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala
@@ -250,7 +250,8 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead
         Vectors.dense(model2.getDocConcentration) absTol 1e-6)
     }
     val lda = new LDA()
-    testEstimatorAndModelReadWrite(lda, dataset, LDASuite.allParamSettings, checkModelData)
+    testEstimatorAndModelReadWrite(lda, dataset, LDASuite.allParamSettings,
+      LDASuite.allParamSettings, checkModelData)
   }
 
   test("read/write DistributedLDAModel") {
@@ -271,6 +272,7 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead
     }
     val lda = new LDA()
     testEstimatorAndModelReadWrite(lda, dataset,
+      LDASuite.allParamSettings ++ Map("optimizer" -> "em"),
       LDASuite.allParamSettings ++ Map("optimizer" -> "em"), checkModelData)
   }
 
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSHSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSHSuite.scala
index ab937685a5..91eac9e733 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSHSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSHSuite.scala
@@ -63,7 +63,7 @@ class BucketedRandomProjectionLSHSuite
     }
     val mh = new BucketedRandomProjectionLSH()
     val settings = Map("inputCol" -> "keys", "outputCol" -> "values", "bucketLength" -> 1.0)
-    testEstimatorAndModelReadWrite(mh, dataset, settings, checkModelData)
+    testEstimatorAndModelReadWrite(mh, dataset, settings, settings, checkModelData)
   }
 
   test("hashFunction") {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala
index 482e5d5426..d6925da97d 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala
@@ -151,7 +151,8 @@ class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext
       assert(model.selectedFeatures === model2.selectedFeatures)
     }
     val nb = new ChiSqSelector
-    testEstimatorAndModelReadWrite(nb, dataset, ChiSqSelectorSuite.allParamSettings, checkModelData)
+    testEstimatorAndModelReadWrite(nb, dataset, ChiSqSelectorSuite.allParamSettings,
+      ChiSqSelectorSuite.allParamSettings, checkModelData)
   }
 
   test("should support all NumericType labels and not support other types") {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/MinHashLSHSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/MinHashLSHSuite.scala
index 3461cdf824..a2f009310f 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/MinHashLSHSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/MinHashLSHSuite.scala
@@ -54,7 +54,7 @@ class MinHashLSHSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
     }
     val mh = new MinHashLSH()
     val settings = Map("inputCol" -> "keys", "outputCol" -> "values")
-    testEstimatorAndModelReadWrite(mh, dataset, settings, checkModelData)
+    testEstimatorAndModelReadWrite(mh, dataset, settings, settings, checkModelData)
   }
 
   test("hashFunction") {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala
index 74c7461401..076d55c180 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala
@@ -99,8 +99,8 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
         model2.freqItemsets.sort("items").collect())
     }
     val fPGrowth = new FPGrowth()
-    testEstimatorAndModelReadWrite(
-      fPGrowth, dataset, FPGrowthSuite.allParamSettings, checkModelData)
+    testEstimatorAndModelReadWrite(fPGrowth, dataset, FPGrowthSuite.allParamSettings,
+      FPGrowthSuite.allParamSettings, checkModelData)
   }
 
 }
diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
index e494ea89e6..a177ed13bf 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
@@ -518,37 +518,26 @@ class ALSSuite
   }
 
   test("read/write") {
-    import ALSSuite._
-    val (ratings, _) = genExplicitTestData(numUsers = 4, numItems = 4, rank = 1)
-    val als = new ALS()
-    allEstimatorParamSettings.foreach { case (p, v) =>
-      als.set(als.getParam(p), v)
-    }
     val spark = this.spark
     import spark.implicits._
-    val model = als.fit(ratings.toDF())
-
-    // Test Estimator save/load
-    val als2 = testDefaultReadWrite(als)
-    allEstimatorParamSettings.foreach { case (p, v) =>
-      val param = als.getParam(p)
-      assert(als.get(param).get === als2.get(param).get)
-    }
+    import ALSSuite._
+    val (ratings, _) = genExplicitTestData(numUsers = 4, numItems = 4, rank = 1)
 
-    // Test Model save/load
-    val model2 = testDefaultReadWrite(model)
-    allModelParamSettings.foreach { case (p, v) =>
-      val param = model.getParam(p)
-      assert(model.get(param).get === model2.get(param).get)
-    }
-    assert(model.rank === model2.rank)
     def getFactors(df: DataFrame): Set[(Int, Array[Float])] = {
       df.select("id", "features").collect().map { case r =>
         (r.getInt(0), r.getAs[Array[Float]](1))
       }.toSet
     }
-    assert(getFactors(model.userFactors) === getFactors(model2.userFactors))
-    assert(getFactors(model.itemFactors) === getFactors(model2.itemFactors))
+
+    def checkModelData(model: ALSModel, model2: ALSModel): Unit = {
+      assert(model.rank === model2.rank)
+      assert(getFactors(model.userFactors) === getFactors(model2.userFactors))
+      assert(getFactors(model.itemFactors) === getFactors(model2.itemFactors))
+    }
+
+    val als = new ALS()
+    testEstimatorAndModelReadWrite(als, ratings.toDF(), allEstimatorParamSettings,
+      allModelParamSettings, checkModelData)
   }
 
   test("input type validation") {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala
index 3cd4b0ac30..708185a094 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala
@@ -419,7 +419,8 @@ class AFTSurvivalRegressionSuite
     }
     val aft = new AFTSurvivalRegression()
     testEstimatorAndModelReadWrite(aft, datasetMultivariate,
-      AFTSurvivalRegressionSuite.allParamSettings, checkModelData)
+      AFTSurvivalRegressionSuite.allParamSettings, AFTSurvivalRegressionSuite.allParamSettings,
+      checkModelData)
   }
 
   test("SPARK-15892: Incorrectly merged AFTAggregator with zero total count") {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
index 15fa26e8b5..0e91284d03 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
@@ -165,16 +165,17 @@ class DecisionTreeRegressorSuite
     val categoricalData: DataFrame =
       TreeTests.setMetadata(rdd, Map(0 -> 2, 1 -> 3), numClasses = 0)
     testEstimatorAndModelReadWrite(dt, categoricalData,
-      TreeTests.allParamSettings, checkModelData)
+      TreeTests.allParamSettings, TreeTests.allParamSettings, checkModelData)
 
     // Continuous splits with tree depth 2
     val continuousData: DataFrame =
       TreeTests.setMetadata(rdd, Map.empty[Int, Int], numClasses = 0)
     testEstimatorAndModelReadWrite(dt, continuousData,
-      TreeTests.allParamSettings, checkModelData)
+      TreeTests.allParamSettings, TreeTests.allParamSettings, checkModelData)
 
     // Continuous splits with tree depth 0
     testEstimatorAndModelReadWrite(dt, continuousData,
+      TreeTests.allParamSettings ++ Map("maxDepth" -> 0),
       TreeTests.allParamSettings ++ Map("maxDepth" -> 0), checkModelData)
   }
 }
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
index dcf3f9a1ea..03c2f97797 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
@@ -184,7 +184,8 @@ class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext
     val allParamSettings = TreeTests.allParamSettings ++ Map("lossType" -> "squared")
     val continuousData: DataFrame =
       TreeTests.setMetadata(rdd, Map.empty[Int, Int], numClasses = 0)
-    testEstimatorAndModelReadWrite(gbt, continuousData, allParamSettings, checkModelData)
+    testEstimatorAndModelReadWrite(gbt, continuousData, allParamSettings,
+      allParamSettings, checkModelData)
   }
 }
 
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
index add28a72b6..401911763f 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
@@ -1418,6 +1418,7 @@ class GeneralizedLinearRegressionSuite
 
     val glr = new GeneralizedLinearRegression()
     testEstimatorAndModelReadWrite(glr, datasetPoissonLog,
+      GeneralizedLinearRegressionSuite.allParamSettings,
       GeneralizedLinearRegressionSuite.allParamSettings, checkModelData)
   }
 
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala
index 8cbb2acad2..f41a3601b1 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala
@@ -178,7 +178,7 @@ class IsotonicRegressionSuite
 
     val ir = new IsotonicRegression()
     testEstimatorAndModelReadWrite(ir, dataset, IsotonicRegressionSuite.allParamSettings,
-      checkModelData)
+      IsotonicRegressionSuite.allParamSettings, checkModelData)
   }
 
   test("should support all NumericType labels and weights, and not support other types") {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
index 584a1b272f..6a51e75e12 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
@@ -985,7 +985,7 @@ class LinearRegressionSuite
     }
     val lr = new LinearRegression()
     testEstimatorAndModelReadWrite(lr, datasetWithWeight, LinearRegressionSuite.allParamSettings,
-      checkModelData)
+      LinearRegressionSuite.allParamSettings, checkModelData)
   }
 
   test("should support all NumericType labels and weights, and not support other types") {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
index c08335f9f8..3bf0445ebd 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
@@ -124,7 +124,8 @@ class RandomForestRegressorSuite extends SparkFunSuite with MLlibTestSparkContex
 
     val continuousData: DataFrame =
       TreeTests.setMetadata(rdd, Map.empty[Int, Int], numClasses = 0)
-    testEstimatorAndModelReadWrite(rf, continuousData, allParamSettings, checkModelData)
+    testEstimatorAndModelReadWrite(rf, continuousData, allParamSettings,
+      allParamSettings, checkModelData)
   }
 }
 
diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala
index 553b8725b3..bfe8f12258 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala
@@ -85,11 +85,12 @@ trait DefaultReadWriteTest extends TempDirectory { self: Suite =>
    *  - Check Params on Estimator and Model
    *  - Compare model data
    *
-   * This requires that the [[Estimator]] and [[Model]] share the same set of [[Param]]s.
+   * This requires that [[Model]]'s [[Param]]s should be a subset of [[Estimator]]'s [[Param]]s.
    *
    * @param estimator  Estimator to test
    * @param dataset  Dataset to pass to [[Estimator.fit()]]
-   * @param testParams  Set of [[Param]] values to set in estimator
+   * @param testEstimatorParams  Set of [[Param]] values to set in estimator
+   * @param testModelParams Set of [[Param]] values to set in model
    * @param checkModelData  Method which takes the original and loaded [[Model]] and compares their
    *                        data.  This method does not need to check [[Param]] values.
    * @tparam E  Type of [[Estimator]]
@@ -99,24 +100,25 @@ trait DefaultReadWriteTest extends TempDirectory { self: Suite =>
     E <: Estimator[M] with MLWritable, M <: Model[M] with MLWritable](
       estimator: E,
       dataset: Dataset[_],
-      testParams: Map[String, Any],
+      testEstimatorParams: Map[String, Any],
+      testModelParams: Map[String, Any],
       checkModelData: (M, M) => Unit): Unit = {
     // Set some Params to make sure set Params are serialized.
-    testParams.foreach { case (p, v) =>
+    testEstimatorParams.foreach { case (p, v) =>
       estimator.set(estimator.getParam(p), v)
     }
     val model = estimator.fit(dataset)
 
     // Test Estimator save/load
     val estimator2 = testDefaultReadWrite(estimator)
-    testParams.foreach { case (p, v) =>
+    testEstimatorParams.foreach { case (p, v) =>
       val param = estimator.getParam(p)
       assert(estimator.get(param).get === estimator2.get(param).get)
     }
 
     // Test Model save/load
     val model2 = testDefaultReadWrite(model)
-    testParams.foreach { case (p, v) =>
+    testModelParams.foreach { case (p, v) =>
       val param = model.getParam(p)
       assert(model.get(param).get === model2.get(param).get)
     }
-- 
GitLab