diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
index c711e7fa9dc677983a1d8bf8f0db953d0ad997f6..10de50306a5ce4e488d633dcde7bbe4024e8ea85 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
@@ -372,16 +372,18 @@ class DecisionTreeClassifierSuite

     // Categorical splits with tree depth 2
     val categoricalData: DataFrame =
       TreeTests.setMetadata(rdd, Map(0 -> 2, 1 -> 3), numClasses = 2)
-    testEstimatorAndModelReadWrite(dt, categoricalData, allParamSettings, checkModelData)
+    testEstimatorAndModelReadWrite(dt, categoricalData, allParamSettings,
+      allParamSettings, checkModelData)

     // Continuous splits with tree depth 2
     val continuousData: DataFrame =
       TreeTests.setMetadata(rdd, Map.empty[Int, Int], numClasses = 2)
-    testEstimatorAndModelReadWrite(dt, continuousData, allParamSettings, checkModelData)
+    testEstimatorAndModelReadWrite(dt, continuousData, allParamSettings,
+      allParamSettings, checkModelData)

     // Continuous splits with tree depth 0
     testEstimatorAndModelReadWrite(dt, continuousData, allParamSettings ++ Map("maxDepth" -> 0),
-      checkModelData)
+      allParamSettings ++ Map("maxDepth" -> 0), checkModelData)
   }
 }
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
index 0598943c3d4be0ca0dae9669032f76a8ca72fd39..0cddb37281b39269d917baa1fac75e7904a1925c 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
@@ -374,7 +374,8 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext

     val continuousData: DataFrame =
       TreeTests.setMetadata(rdd, Map.empty[Int, Int], numClasses = 2)
-    testEstimatorAndModelReadWrite(gbt, continuousData, allParamSettings, checkModelData)
+    testEstimatorAndModelReadWrite(gbt, continuousData, allParamSettings,
+      allParamSettings, checkModelData)
   }
 }

diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala
index fe47176a4aaa6f06c4b0758687ac53498289a090..4c63a2a88c6c6fdd6628f3c14a6698b1e8e03eb4 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala
@@ -232,7 +232,7 @@ class LinearSVCSuite extends SparkFunSuite with MLlibTestSparkContext with Defau
     }
     val svm = new LinearSVC()
     testEstimatorAndModelReadWrite(svm, smallBinaryDataset, LinearSVCSuite.allParamSettings,
-      checkModelData)
+      LinearSVCSuite.allParamSettings, checkModelData)
   }
 }

diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
index d89a958eed45adaeaf1f7e1e63ab5d09f9278188..affaa573749e8f90ab0a582c098efb332a2c9f70 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
@@ -2089,7 +2089,7 @@ class LogisticRegressionSuite
     }
     val lr = new LogisticRegression()
     testEstimatorAndModelReadWrite(lr, smallBinaryDataset, LogisticRegressionSuite.allParamSettings,
-      checkModelData)
+      LogisticRegressionSuite.allParamSettings, checkModelData)
   }

   test("should support all NumericType labels and weights, and not support other types") {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala
index 37d7991fe8dd851584830d0249472e1861b3f69a..4d5d299d1408f1ed309d4c9f8e28f5fd020c9667 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala
@@ -280,7 +280,8 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
       assert(model.theta === model2.theta)
     }
     val nb = new NaiveBayes()
-    testEstimatorAndModelReadWrite(nb, dataset, NaiveBayesSuite.allParamSettings, checkModelData)
+    testEstimatorAndModelReadWrite(nb, dataset, NaiveBayesSuite.allParamSettings,
+      NaiveBayesSuite.allParamSettings, checkModelData)
   }

   test("should support all NumericType labels and weights, and not support other types") {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
index 44e1585ee514b48d2882e90dff9f28696d3f73a3..c3003cec73b4166bf4c4af2c10029e678f5107a4 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
@@ -218,7 +218,8 @@ class RandomForestClassifierSuite

     val continuousData: DataFrame =
       TreeTests.setMetadata(rdd, Map.empty[Int, Int], numClasses = 2)
-    testEstimatorAndModelReadWrite(rf, continuousData, allParamSettings, checkModelData)
+    testEstimatorAndModelReadWrite(rf, continuousData, allParamSettings,
+      allParamSettings, checkModelData)
   }
 }

diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala
index 30513c1e276aec28c3ce4b9db4d20a369192539d..200a892f6c694bbc8b1bc5270d19e2fa59185936 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala
@@ -138,8 +138,8 @@ class BisectingKMeansSuite
       assert(model.clusterCenters === model2.clusterCenters)
     }
     val bisectingKMeans = new BisectingKMeans()
-    testEstimatorAndModelReadWrite(
-      bisectingKMeans, dataset, BisectingKMeansSuite.allParamSettings, checkModelData)
+    testEstimatorAndModelReadWrite(bisectingKMeans, dataset, BisectingKMeansSuite.allParamSettings,
+      BisectingKMeansSuite.allParamSettings, checkModelData)
   }
 }

diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala
index c500c5b3e365a5e0f162e0510f18401efd51dbf6..61da897b666f4b97084d62524563c15428d8ac04 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala
@@ -163,7 +163,7 @@ class GaussianMixtureSuite extends SparkFunSuite with MLlibTestSparkContext
       assert(model.gaussians.map(_.cov) === model2.gaussians.map(_.cov))
     }
     val gm = new GaussianMixture()
-    testEstimatorAndModelReadWrite(gm, dataset,
+    testEstimatorAndModelReadWrite(gm, dataset, GaussianMixtureSuite.allParamSettings,
       GaussianMixtureSuite.allParamSettings, checkModelData)
   }

diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
index e10127f7d108f3e9c31a02981c554b8d1fab6aa4..ca05b9c389f656e2e8d8cbd087d7f1ae776b9162 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
@@ -150,7 +150,8 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR
       assert(model.clusterCenters === model2.clusterCenters)
     }
     val kmeans = new KMeans()
-    testEstimatorAndModelReadWrite(kmeans, dataset, KMeansSuite.allParamSettings, checkModelData)
+    testEstimatorAndModelReadWrite(kmeans, dataset, KMeansSuite.allParamSettings,
+      KMeansSuite.allParamSettings, checkModelData)
   }
 }

diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala
index 9aa11fbdbe86802f818dc4b8b7960f76a2c46d4d..75aa0be61a3ed29a839033ece8f1135af10c2488 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala
@@ -250,7 +250,8 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead
       Vectors.dense(model2.getDocConcentration) absTol 1e-6)
     }
     val lda = new LDA()
-    testEstimatorAndModelReadWrite(lda, dataset, LDASuite.allParamSettings, checkModelData)
+    testEstimatorAndModelReadWrite(lda, dataset, LDASuite.allParamSettings,
+      LDASuite.allParamSettings, checkModelData)
   }

   test("read/write DistributedLDAModel") {
@@ -271,6 +272,7 @@
     }
     val lda = new LDA()
     testEstimatorAndModelReadWrite(lda, dataset,
+      LDASuite.allParamSettings ++ Map("optimizer" -> "em"),
       LDASuite.allParamSettings ++ Map("optimizer" -> "em"), checkModelData)
   }

diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSHSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSHSuite.scala
index ab937685a555c074430d6b47caa33bc955779fba..91eac9e73331255c5393b3884ae72bdf7091eb20 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSHSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSHSuite.scala
@@ -63,7 +63,7 @@ class BucketedRandomProjectionLSHSuite
     }
     val mh = new BucketedRandomProjectionLSH()
     val settings = Map("inputCol" -> "keys", "outputCol" -> "values", "bucketLength" -> 1.0)
-    testEstimatorAndModelReadWrite(mh, dataset, settings, checkModelData)
+    testEstimatorAndModelReadWrite(mh, dataset, settings, settings, checkModelData)
   }

   test("hashFunction") {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala
index 482e5d54260d4bf8a7c97e7d955df22b6710a489..d6925da97d57e603e44c4d5af6158d6b4396cefe 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala
@@ -151,7 +151,8 @@ class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext
       assert(model.selectedFeatures === model2.selectedFeatures)
     }
     val nb = new ChiSqSelector
-    testEstimatorAndModelReadWrite(nb, dataset, ChiSqSelectorSuite.allParamSettings, checkModelData)
+    testEstimatorAndModelReadWrite(nb, dataset, ChiSqSelectorSuite.allParamSettings,
+      ChiSqSelectorSuite.allParamSettings, checkModelData)
   }

   test("should support all NumericType labels and not support other types") {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/MinHashLSHSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/MinHashLSHSuite.scala
index 3461cdf82460f0398269a796ca0cde4d5b9c601e..a2f009310fd7a53ebd1bfe47f8468e0ebe462e9a 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/MinHashLSHSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/MinHashLSHSuite.scala
@@ -54,7 +54,7 @@ class MinHashLSHSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
     }
     val mh = new MinHashLSH()
     val settings = Map("inputCol" -> "keys", "outputCol" -> "values")
-    testEstimatorAndModelReadWrite(mh, dataset, settings, checkModelData)
+    testEstimatorAndModelReadWrite(mh, dataset, settings, settings, checkModelData)
   }

   test("hashFunction") {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala
index 74c746140190561b522828ee1fd6f25cea295884..076d55c180548c622b369a86cb2019d3cd05f745 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala
@@ -99,8 +99,8 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
         model2.freqItemsets.sort("items").collect())
     }
     val fPGrowth = new FPGrowth()
-    testEstimatorAndModelReadWrite(
-      fPGrowth, dataset, FPGrowthSuite.allParamSettings, checkModelData)
+    testEstimatorAndModelReadWrite(fPGrowth, dataset, FPGrowthSuite.allParamSettings,
+      FPGrowthSuite.allParamSettings, checkModelData)
   }
 }

diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
index e494ea89e63bd1c56606a045e32a31ceebb92017..a177ed13bf8efb06e55e1594ada20316d6ea21c8 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
@@ -518,37 +518,26 @@ class ALSSuite
   }

   test("read/write") {
-    import ALSSuite._
-    val (ratings, _) = genExplicitTestData(numUsers = 4, numItems = 4, rank = 1)
-    val als = new ALS()
-    allEstimatorParamSettings.foreach { case (p, v) =>
-      als.set(als.getParam(p), v)
-    }
     val spark = this.spark
     import spark.implicits._
-    val model = als.fit(ratings.toDF())
-
-    // Test Estimator save/load
-    val als2 = testDefaultReadWrite(als)
-    allEstimatorParamSettings.foreach { case (p, v) =>
-      val param = als.getParam(p)
-      assert(als.get(param).get === als2.get(param).get)
-    }
+    import ALSSuite._
+    val (ratings, _) = genExplicitTestData(numUsers = 4, numItems = 4, rank = 1)

-    // Test Model save/load
-    val model2 = testDefaultReadWrite(model)
-    allModelParamSettings.foreach { case (p, v) =>
-      val param = model.getParam(p)
-      assert(model.get(param).get === model2.get(param).get)
-    }
-    assert(model.rank === model2.rank)
     def getFactors(df: DataFrame): Set[(Int, Array[Float])] = {
       df.select("id", "features").collect().map { case r =>
         (r.getInt(0), r.getAs[Array[Float]](1))
       }.toSet
     }
-    assert(getFactors(model.userFactors) === getFactors(model2.userFactors))
-    assert(getFactors(model.itemFactors) === getFactors(model2.itemFactors))
+
+    def checkModelData(model: ALSModel, model2: ALSModel): Unit = {
+      assert(model.rank === model2.rank)
+      assert(getFactors(model.userFactors) === getFactors(model2.userFactors))
+      assert(getFactors(model.itemFactors) === getFactors(model2.itemFactors))
+    }
+
+    val als = new ALS()
+    testEstimatorAndModelReadWrite(als, ratings.toDF(), allEstimatorParamSettings,
+      allModelParamSettings, checkModelData)
   }

   test("input type validation") {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala
index 3cd4b0ac308efae7952d09041f3bd463e6542b47..708185a0943df89b5e606f73a64a25e5af8db654 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala
@@ -419,7 +419,8 @@ class AFTSurvivalRegressionSuite
     }
     val aft = new AFTSurvivalRegression()
     testEstimatorAndModelReadWrite(aft, datasetMultivariate,
-      AFTSurvivalRegressionSuite.allParamSettings, checkModelData)
+      AFTSurvivalRegressionSuite.allParamSettings, AFTSurvivalRegressionSuite.allParamSettings,
+      checkModelData)
   }

   test("SPARK-15892: Incorrectly merged AFTAggregator with zero total count") {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
index 15fa26e8b527288628248ce2b36162757b105523..0e91284d03d983a9c5e091461a8f4240ed097a65 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
@@ -165,16 +165,17 @@ class DecisionTreeRegressorSuite
     val categoricalData: DataFrame =
       TreeTests.setMetadata(rdd, Map(0 -> 2, 1 -> 3), numClasses = 0)
     testEstimatorAndModelReadWrite(dt, categoricalData,
-      TreeTests.allParamSettings, checkModelData)
+      TreeTests.allParamSettings, TreeTests.allParamSettings, checkModelData)

     // Continuous splits with tree depth 2
     val continuousData: DataFrame =
       TreeTests.setMetadata(rdd, Map.empty[Int, Int], numClasses = 0)
     testEstimatorAndModelReadWrite(dt, continuousData,
-      TreeTests.allParamSettings, checkModelData)
+      TreeTests.allParamSettings, TreeTests.allParamSettings, checkModelData)

     // Continuous splits with tree depth 0
     testEstimatorAndModelReadWrite(dt, continuousData,
+      TreeTests.allParamSettings ++ Map("maxDepth" -> 0),
       TreeTests.allParamSettings ++ Map("maxDepth" -> 0), checkModelData)
   }
 }
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
index dcf3f9a1ea9b263ab24d3f8c918d0a12260c49e7..03c2f97797bce26992669dfc5bd8cc857c45df26 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
@@ -184,7 +184,8 @@ class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext
     val allParamSettings = TreeTests.allParamSettings ++ Map("lossType" -> "squared")
     val continuousData: DataFrame =
       TreeTests.setMetadata(rdd, Map.empty[Int, Int], numClasses = 0)
-    testEstimatorAndModelReadWrite(gbt, continuousData, allParamSettings, checkModelData)
+    testEstimatorAndModelReadWrite(gbt, continuousData, allParamSettings,
+      allParamSettings, checkModelData)
   }
 }

diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
index add28a72b6808fe3f3d22c418f1a125ec1edcb65..401911763fa3bd0eaab94b20f100ab8a198bfac5 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
@@ -1418,6 +1418,7 @@ class GeneralizedLinearRegressionSuite

     val glr = new GeneralizedLinearRegression()
     testEstimatorAndModelReadWrite(glr, datasetPoissonLog,
+      GeneralizedLinearRegressionSuite.allParamSettings,
       GeneralizedLinearRegressionSuite.allParamSettings, checkModelData)
   }

diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala
index 8cbb2acad243e092e95363a19d809fe4957f6a5c..f41a3601b1fa8c716f7aaa0019c0fc34ab15b3f1 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala
@@ -178,7 +178,7 @@ class IsotonicRegressionSuite

     val ir = new IsotonicRegression()
     testEstimatorAndModelReadWrite(ir, dataset, IsotonicRegressionSuite.allParamSettings,
-      checkModelData)
+      IsotonicRegressionSuite.allParamSettings, checkModelData)
   }

   test("should support all NumericType labels and weights, and not support other types") {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
index 584a1b272f6c806d93efc70843226220e607a224..6a51e75e12a36f84065552c6712814f336ef29f8 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
@@ -985,7 +985,7 @@ class LinearRegressionSuite
     }
     val lr = new LinearRegression()
     testEstimatorAndModelReadWrite(lr, datasetWithWeight, LinearRegressionSuite.allParamSettings,
-      checkModelData)
+      LinearRegressionSuite.allParamSettings, checkModelData)
   }

   test("should support all NumericType labels and weights, and not support other types") {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
index c08335f9f84afec87c9af29524ae19838ac019ae..3bf0445ebd3dd8b073637cd36d95e39bb85e8916 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
@@ -124,7 +124,8 @@ class RandomForestRegressorSuite extends SparkFunSuite with MLlibTestSparkContex

     val continuousData: DataFrame =
       TreeTests.setMetadata(rdd, Map.empty[Int, Int], numClasses = 0)
-    testEstimatorAndModelReadWrite(rf, continuousData, allParamSettings, checkModelData)
+    testEstimatorAndModelReadWrite(rf, continuousData, allParamSettings,
+      allParamSettings, checkModelData)
   }
 }

diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala
index 553b8725b30a31f4f232ef042f6adfbe56dce30f..bfe8f12258bb8539bed0869bdba9d0ee35006495 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala
@@ -85,11 +85,12 @@ trait DefaultReadWriteTest extends TempDirectory { self: Suite =>
    * - Check Params on Estimator and Model
    * - Compare model data
    *
-   * This requires that the [[Estimator]] and [[Model]] share the same set of [[Param]]s.
+   * This requires that the [[Model]]'s [[Param]]s are a subset of the [[Estimator]]'s [[Param]]s.
    *
    * @param estimator Estimator to test
    * @param dataset Dataset to pass to [[Estimator.fit()]]
-   * @param testParams Set of [[Param]] values to set in estimator
+   * @param testEstimatorParams Set of [[Param]] values to set in estimator
+   * @param testModelParams Set of [[Param]] values to set in model
    * @param checkModelData Method which takes the original and loaded [[Model]] and compares their
    *                       data. This method does not need to check [[Param]] values.
    * @tparam E Type of [[Estimator]]
@@ -99,24 +100,25 @@ trait DefaultReadWriteTest extends TempDirectory { self: Suite =>
     E <: Estimator[M] with MLWritable, M <: Model[M] with MLWritable](
       estimator: E,
       dataset: Dataset[_],
-      testParams: Map[String, Any],
+      testEstimatorParams: Map[String, Any],
+      testModelParams: Map[String, Any],
       checkModelData: (M, M) => Unit): Unit = {
     // Set some Params to make sure set Params are serialized.
-    testParams.foreach { case (p, v) =>
+    testEstimatorParams.foreach { case (p, v) =>
       estimator.set(estimator.getParam(p), v)
     }
     val model = estimator.fit(dataset)

     // Test Estimator save/load
     val estimator2 = testDefaultReadWrite(estimator)
-    testParams.foreach { case (p, v) =>
+    testEstimatorParams.foreach { case (p, v) =>
       val param = estimator.getParam(p)
       assert(estimator.get(param).get === estimator2.get(param).get)
     }

     // Test Model save/load
     val model2 = testDefaultReadWrite(model)
-    testParams.foreach { case (p, v) =>
+    testModelParams.foreach { case (p, v) =>
       val param = model.getParam(p)
       assert(model.get(param).get === model2.get(param).get)
     }
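
Reviewer note (not part of the patch): with the new signature, callers pass one map of Params to set on the estimator and a second map of Params to verify on the reloaded model. Below is a minimal sketch of a suite using the updated helper; the suite name, the ALS-style param values, and the toy ratings data are assumptions for illustration only, not code from this change:

    import org.apache.spark.SparkFunSuite
    import org.apache.spark.ml.recommendation.{ALS, ALSModel}
    import org.apache.spark.ml.util.DefaultReadWriteTest
    import org.apache.spark.mllib.util.MLlibTestSparkContext

    class ALSReadWriteSketchSuite extends SparkFunSuite
      with MLlibTestSparkContext with DefaultReadWriteTest {

      test("read/write with separate estimator/model params") {
        val spark = this.spark
        import spark.implicits._

        // ALSModel exposes only a few of ALS's Params, so the model map is a
        // subset of the estimator map. Any Param named in the model map must be
        // explicitly set on the estimator (fit copies set values to the model);
        // a default alone would make model.get(param) return None in the check.
        val modelParams: Map[String, Any] = Map("predictionCol" -> "myPrediction")
        val estimatorParams: Map[String, Any] =
          modelParams ++ Map("rank" -> 1, "maxIter" -> 1, "regParam" -> 0.01)

        // Tiny ratings table using ALS's default column names.
        val ratings = Seq((0, 0, 1.0f), (0, 1, 2.0f), (1, 1, 3.0f))
          .toDF("user", "item", "rating")

        def checkModelData(m: ALSModel, m2: ALSModel): Unit = {
          assert(m.rank === m2.rank)
        }

        testEstimatorAndModelReadWrite(new ALS(), ratings, estimatorParams,
          modelParams, checkModelData)
      }
    }

This is also why ALSSuite above defines allEstimatorParamSettings as allModelParamSettings plus the estimator-only settings: the helper checks the model map against the reloaded model, so every Param in it has to exist, and be set, on both sides of the round trip.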