From 1fa58868bc6635ff2119264665bd3d00b4b1253a Mon Sep 17 00:00:00 2001
From: Yanbo Liang <ybliang8@gmail.com>
Date: Wed, 8 Mar 2017 02:05:01 -0800
Subject: [PATCH] [ML][MINOR] Separate estimator and model params for
 read/write test.

## What changes were proposed in this pull request?
Since an ```Estimator``` and its ```Model``` are not required to share the same params (see ```ALSParams``` and ```ALSModelParams```), we should pass the test params for the estimator and the model separately to ```testEstimatorAndModelReadWrite```.

## How was this patch tested?
Existing tests.

Author: Yanbo Liang <ybliang8@gmail.com>

Closes #17151 from yanboliang/test-rw.
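To make the new contract concrete: callers now supply one param map for the ```Estimator``` and one for the ```Model```, and the model map must be a subset of the estimator's params. The sketch below is illustrative only, adapted from the updated ```ALSSuite``` test in this patch (```genExplicitTestData```, ```allEstimatorParamSettings```, and ```allModelParamSettings``` are that suite's existing helpers; it assumes a suite mixing in ```SparkFunSuite```, ```MLlibTestSparkContext```, and ```DefaultReadWriteTest```):

```scala
// Sketch of the new call shape (identifiers as in ALSSuite below).
test("read/write") {
  val spark = this.spark
  import spark.implicits._
  import ALSSuite._
  val (ratings, _) = genExplicitTestData(numUsers = 4, numItems = 4, rank = 1)

  // Compares model data only; param round-tripping is checked by the helper.
  // Trimmed to a rank check here for brevity.
  def checkModelData(model: ALSModel, model2: ALSModel): Unit = {
    assert(model.rank === model2.rank)
  }

  val als = new ALS()
  testEstimatorAndModelReadWrite(als, ratings.toDF(),
    allEstimatorParamSettings, // set on the Estimator and verified after reload
    allModelParamSettings,     // verified on the Model after reload (a subset)
    checkModelData)
}
```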
---
 .../DecisionTreeClassifierSuite.scala         |  8 +++--
 .../classification/GBTClassifierSuite.scala   |  3 +-
 .../ml/classification/LinearSVCSuite.scala    |  2 +-
 .../LogisticRegressionSuite.scala             |  2 +-
 .../ml/classification/NaiveBayesSuite.scala   |  3 +-
 .../RandomForestClassifierSuite.scala         |  3 +-
 .../ml/clustering/BisectingKMeansSuite.scala  |  4 +--
 .../ml/clustering/GaussianMixtureSuite.scala  |  2 +-
 .../spark/ml/clustering/KMeansSuite.scala     |  3 +-
 .../apache/spark/ml/clustering/LDASuite.scala |  4 ++-
 .../BucketedRandomProjectionLSHSuite.scala    |  2 +-
 .../spark/ml/feature/ChiSqSelectorSuite.scala |  3 +-
 .../spark/ml/feature/MinHashLSHSuite.scala    |  2 +-
 .../apache/spark/ml/fpm/FPGrowthSuite.scala   |  4 +--
 .../spark/ml/recommendation/ALSSuite.scala    | 35 +++++++------------
 .../AFTSurvivalRegressionSuite.scala          |  3 +-
 .../DecisionTreeRegressorSuite.scala          |  5 +--
 .../ml/regression/GBTRegressorSuite.scala     |  3 +-
 .../GeneralizedLinearRegressionSuite.scala    |  1 +
 .../regression/IsotonicRegressionSuite.scala  |  2 +-
 .../ml/regression/LinearRegressionSuite.scala |  2 +-
 .../RandomForestRegressorSuite.scala          |  3 +-
 .../spark/ml/util/DefaultReadWriteTest.scala  | 14 ++++----
 23 files changed, 59 insertions(+), 54 deletions(-)

diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
index c711e7fa9d..10de50306a 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
@@ -372,16 +372,18 @@ class DecisionTreeClassifierSuite
 
     // Categorical splits with tree depth 2
     val categoricalData: DataFrame =
       TreeTests.setMetadata(rdd, Map(0 -> 2, 1 -> 3), numClasses = 2)
-    testEstimatorAndModelReadWrite(dt, categoricalData, allParamSettings, checkModelData)
+    testEstimatorAndModelReadWrite(dt, categoricalData, allParamSettings,
+      allParamSettings, checkModelData)
 
     // Continuous splits with tree depth 2
     val continuousData: DataFrame =
       TreeTests.setMetadata(rdd, Map.empty[Int, Int], numClasses = 2)
-    testEstimatorAndModelReadWrite(dt, continuousData, allParamSettings, checkModelData)
+    testEstimatorAndModelReadWrite(dt, continuousData, allParamSettings,
+      allParamSettings, checkModelData)
 
     // Continuous splits with tree depth 0
     testEstimatorAndModelReadWrite(dt, continuousData, allParamSettings ++ Map("maxDepth" -> 0),
-      checkModelData)
+      allParamSettings ++ Map("maxDepth" -> 0), checkModelData)
   }
 }
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
index 0598943c3d..0cddb37281 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
@@ -374,7 +374,8 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext
 
     val continuousData: DataFrame =
       TreeTests.setMetadata(rdd, Map.empty[Int, Int], numClasses = 2)
-    testEstimatorAndModelReadWrite(gbt, continuousData, allParamSettings, checkModelData)
+    testEstimatorAndModelReadWrite(gbt, continuousData, allParamSettings,
+      allParamSettings, checkModelData)
 
   }
 }
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala
index fe47176a4a..4c63a2a88c 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala
@@ -232,7 +232,7 @@ class LinearSVCSuite extends SparkFunSuite with MLlibTestSparkContext with Defau
     }
     val svm = new LinearSVC()
     testEstimatorAndModelReadWrite(svm, smallBinaryDataset, LinearSVCSuite.allParamSettings,
-      checkModelData)
+      LinearSVCSuite.allParamSettings, checkModelData)
   }
 
 }
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
index d89a958eed..affaa57374 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
@@ -2089,7 +2089,7 @@ class LogisticRegressionSuite
     }
     val lr = new LogisticRegression()
     testEstimatorAndModelReadWrite(lr, smallBinaryDataset, LogisticRegressionSuite.allParamSettings,
-      checkModelData)
+      LogisticRegressionSuite.allParamSettings, checkModelData)
   }
 
   test("should support all NumericType labels and weights, and not support other types") {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala
index 37d7991fe8..4d5d299d14 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala
@@ -280,7 +280,8 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
       assert(model.theta === model2.theta)
     }
     val nb = new NaiveBayes()
-    testEstimatorAndModelReadWrite(nb, dataset, NaiveBayesSuite.allParamSettings, checkModelData)
+    testEstimatorAndModelReadWrite(nb, dataset, NaiveBayesSuite.allParamSettings,
+      NaiveBayesSuite.allParamSettings, checkModelData)
   }
 
   test("should support all NumericType labels and weights, and not support other types") {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
index 44e1585ee5..c3003cec73 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
@@ -218,7 +218,8 @@ class RandomForestClassifierSuite
 
     val continuousData: DataFrame =
       TreeTests.setMetadata(rdd, Map.empty[Int, Int], numClasses = 2)
-    testEstimatorAndModelReadWrite(rf, continuousData, allParamSettings, checkModelData)
+    testEstimatorAndModelReadWrite(rf, continuousData, allParamSettings,
+      allParamSettings, checkModelData)
 
   }
 }
diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala
index 30513c1e27..200a892f6c 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala
@@ -138,8 +138,8 @@ class BisectingKMeansSuite
 
       assert(model.clusterCenters === model2.clusterCenters)
     }
     val bisectingKMeans = new BisectingKMeans()
-    testEstimatorAndModelReadWrite(
-      bisectingKMeans, dataset, BisectingKMeansSuite.allParamSettings, checkModelData)
+    testEstimatorAndModelReadWrite(bisectingKMeans, dataset, BisectingKMeansSuite.allParamSettings,
+      BisectingKMeansSuite.allParamSettings, checkModelData)
   }
 }
diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala
index c500c5b3e3..61da897b66 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala
@@ -163,7 +163,7 @@ class GaussianMixtureSuite extends SparkFunSuite with MLlibTestSparkContext
       assert(model.gaussians.map(_.cov) === model2.gaussians.map(_.cov))
     }
     val gm = new GaussianMixture()
-    testEstimatorAndModelReadWrite(gm, dataset,
+    testEstimatorAndModelReadWrite(gm, dataset, GaussianMixtureSuite.allParamSettings,
       GaussianMixtureSuite.allParamSettings, checkModelData)
   }
 
diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
index e10127f7d1..ca05b9c389 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
@@ -150,7 +150,8 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR
       assert(model.clusterCenters === model2.clusterCenters)
     }
     val kmeans = new KMeans()
-    testEstimatorAndModelReadWrite(kmeans, dataset, KMeansSuite.allParamSettings, checkModelData)
+    testEstimatorAndModelReadWrite(kmeans, dataset, KMeansSuite.allParamSettings,
+      KMeansSuite.allParamSettings, checkModelData)
   }
 
 }
diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala
index 9aa11fbdbe..75aa0be61a 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala
@@ -250,7 +250,8 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead
         Vectors.dense(model2.getDocConcentration) absTol 1e-6)
     }
     val lda = new LDA()
-    testEstimatorAndModelReadWrite(lda, dataset, LDASuite.allParamSettings, checkModelData)
+    testEstimatorAndModelReadWrite(lda, dataset, LDASuite.allParamSettings,
+      LDASuite.allParamSettings, checkModelData)
   }
 
   test("read/write DistributedLDAModel") {
@@ -271,6 +272,7 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead
     }
     val lda = new LDA()
     testEstimatorAndModelReadWrite(lda, dataset,
+      LDASuite.allParamSettings ++ Map("optimizer" -> "em"),
       LDASuite.allParamSettings ++ Map("optimizer" -> "em"), checkModelData)
   }
 
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSHSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSHSuite.scala
index ab937685a5..91eac9e733 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSHSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSHSuite.scala
@@ -63,7 +63,7 @@ class BucketedRandomProjectionLSHSuite
     }
     val mh = new BucketedRandomProjectionLSH()
     val settings = Map("inputCol" -> "keys", "outputCol" -> "values", "bucketLength" -> 1.0)
-    testEstimatorAndModelReadWrite(mh, dataset, settings, checkModelData)
+    testEstimatorAndModelReadWrite(mh, dataset, settings, settings, checkModelData)
   }
 
   test("hashFunction") {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala
index 482e5d5426..d6925da97d 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala
@@ -151,7 +151,8 @@ class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext
      assert(model.selectedFeatures === model2.selectedFeatures)
     }
     val nb = new ChiSqSelector
-    testEstimatorAndModelReadWrite(nb, dataset, ChiSqSelectorSuite.allParamSettings, checkModelData)
+    testEstimatorAndModelReadWrite(nb, dataset, ChiSqSelectorSuite.allParamSettings,
+      ChiSqSelectorSuite.allParamSettings, checkModelData)
   }
 
   test("should support all NumericType labels and not support other types") {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/MinHashLSHSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/MinHashLSHSuite.scala
index 3461cdf824..a2f009310f 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/MinHashLSHSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/MinHashLSHSuite.scala
@@ -54,7 +54,7 @@ class MinHashLSHSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
     }
     val mh = new MinHashLSH()
     val settings = Map("inputCol" -> "keys", "outputCol" -> "values")
-    testEstimatorAndModelReadWrite(mh, dataset, settings, checkModelData)
+    testEstimatorAndModelReadWrite(mh, dataset, settings, settings, checkModelData)
   }
 
   test("hashFunction") {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala
index 74c7461401..076d55c180 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala
@@ -99,8 +99,8 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
       model2.freqItemsets.sort("items").collect())
     }
     val fPGrowth = new FPGrowth()
-    testEstimatorAndModelReadWrite(
-      fPGrowth, dataset, FPGrowthSuite.allParamSettings, checkModelData)
+    testEstimatorAndModelReadWrite(fPGrowth, dataset, FPGrowthSuite.allParamSettings,
+      FPGrowthSuite.allParamSettings, checkModelData)
   }
 
 }
diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
index e494ea89e6..a177ed13bf 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
@@ -518,37 +518,26 @@ class ALSSuite
   }
 
   test("read/write") {
-    import ALSSuite._
-    val (ratings, _) = genExplicitTestData(numUsers = 4, numItems = 4, rank = 1)
-    val als = new ALS()
-    allEstimatorParamSettings.foreach { case (p, v) =>
-      als.set(als.getParam(p), v)
-    }
     val spark = this.spark
     import spark.implicits._
-    val model = als.fit(ratings.toDF())
-
-    // Test Estimator save/load
-    val als2 = testDefaultReadWrite(als)
-    allEstimatorParamSettings.foreach { case (p, v) =>
-      val param = als.getParam(p)
-      assert(als.get(param).get === als2.get(param).get)
-    }
+    import ALSSuite._
+    val (ratings, _) = genExplicitTestData(numUsers = 4, numItems = 4, rank = 1)
 
-    // Test Model save/load
-    val model2 = testDefaultReadWrite(model)
-    allModelParamSettings.foreach { case (p, v) =>
-      val param = model.getParam(p)
-      assert(model.get(param).get === model2.get(param).get)
-    }
-    assert(model.rank === model2.rank)
     def getFactors(df: DataFrame): Set[(Int, Array[Float])] = {
       df.select("id", "features").collect().map { case r =>
         (r.getInt(0), r.getAs[Array[Float]](1))
       }.toSet
     }
-    assert(getFactors(model.userFactors) === getFactors(model2.userFactors))
-    assert(getFactors(model.itemFactors) === getFactors(model2.itemFactors))
+
+    def checkModelData(model: ALSModel, model2: ALSModel): Unit = {
+      assert(model.rank === model2.rank)
+      assert(getFactors(model.userFactors) === getFactors(model2.userFactors))
+      assert(getFactors(model.itemFactors) === getFactors(model2.itemFactors))
+    }
+
+    val als = new ALS()
+    testEstimatorAndModelReadWrite(als, ratings.toDF(), allEstimatorParamSettings,
+      allModelParamSettings, checkModelData)
   }
 
   test("input type validation") {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala
index 3cd4b0ac30..708185a094 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala
@@ -419,7 +419,8 @@ class AFTSurvivalRegressionSuite
     }
     val aft = new AFTSurvivalRegression()
     testEstimatorAndModelReadWrite(aft, datasetMultivariate,
-      AFTSurvivalRegressionSuite.allParamSettings, checkModelData)
+      AFTSurvivalRegressionSuite.allParamSettings, AFTSurvivalRegressionSuite.allParamSettings,
+      checkModelData)
   }
 
   test("SPARK-15892: Incorrectly merged AFTAggregator with zero total count") {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
index 15fa26e8b5..0e91284d03 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
@@ -165,16 +165,17 @@ class DecisionTreeRegressorSuite
     val categoricalData: DataFrame =
       TreeTests.setMetadata(rdd, Map(0 -> 2, 1 -> 3), numClasses = 0)
     testEstimatorAndModelReadWrite(dt, categoricalData,
-      TreeTests.allParamSettings, checkModelData)
+      TreeTests.allParamSettings, TreeTests.allParamSettings, checkModelData)
 
     // Continuous splits with tree depth 2
     val continuousData: DataFrame =
       TreeTests.setMetadata(rdd, Map.empty[Int, Int], numClasses = 0)
     testEstimatorAndModelReadWrite(dt, continuousData,
-      TreeTests.allParamSettings, checkModelData)
+      TreeTests.allParamSettings, TreeTests.allParamSettings, checkModelData)
 
     // Continuous splits with tree depth 0
     testEstimatorAndModelReadWrite(dt, continuousData,
+      TreeTests.allParamSettings ++ Map("maxDepth" -> 0),
       TreeTests.allParamSettings ++ Map("maxDepth" -> 0), checkModelData)
   }
 }
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
index dcf3f9a1ea..03c2f97797 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
@@ -184,7 +184,8 @@ class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext
 
     val allParamSettings = TreeTests.allParamSettings ++ Map("lossType" -> "squared")
     val continuousData: DataFrame =
       TreeTests.setMetadata(rdd, Map.empty[Int, Int], numClasses = 0)
-    testEstimatorAndModelReadWrite(gbt, continuousData, allParamSettings, checkModelData)
+    testEstimatorAndModelReadWrite(gbt, continuousData, allParamSettings,
+      allParamSettings, checkModelData)
   }
 }
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
index add28a72b6..401911763f 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
@@ -1418,6 +1418,7 @@ class GeneralizedLinearRegressionSuite
 
     val glr = new GeneralizedLinearRegression()
     testEstimatorAndModelReadWrite(glr, datasetPoissonLog,
+      GeneralizedLinearRegressionSuite.allParamSettings,
       GeneralizedLinearRegressionSuite.allParamSettings, checkModelData)
   }
 
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala
index 8cbb2acad2..f41a3601b1 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala
@@ -178,7 +178,7 @@ class IsotonicRegressionSuite
 
     val ir = new IsotonicRegression()
     testEstimatorAndModelReadWrite(ir, dataset, IsotonicRegressionSuite.allParamSettings,
-      checkModelData)
+      IsotonicRegressionSuite.allParamSettings, checkModelData)
   }
 
   test("should support all NumericType labels and weights, and not support other types") {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
index 584a1b272f..6a51e75e12 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
@@ -985,7 +985,7 @@ class LinearRegressionSuite
     }
     val lr = new LinearRegression()
     testEstimatorAndModelReadWrite(lr, datasetWithWeight, LinearRegressionSuite.allParamSettings,
-      checkModelData)
+      LinearRegressionSuite.allParamSettings, checkModelData)
   }
 
   test("should support all NumericType labels and weights, and not support other types") {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
index c08335f9f8..3bf0445ebd 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
@@ -124,7 +124,8 @@ class RandomForestRegressorSuite extends SparkFunSuite with MLlibTestSparkContex
 
     val continuousData: DataFrame =
       TreeTests.setMetadata(rdd, Map.empty[Int, Int], numClasses = 0)
-    testEstimatorAndModelReadWrite(rf, continuousData, allParamSettings, checkModelData)
+    testEstimatorAndModelReadWrite(rf, continuousData, allParamSettings,
+      allParamSettings, checkModelData)
 
   }
 }
diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala
index 553b8725b3..bfe8f12258 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala
@@ -85,11 +85,12 @@ trait DefaultReadWriteTest extends TempDirectory { self: Suite =>
    *  - Check Params on Estimator and Model
    *  - Compare model data
    *
-   * This requires that the [[Estimator]] and [[Model]] share the same set of [[Param]]s.
+   * This requires that [[Model]]'s [[Param]]s should be a subset of [[Estimator]]'s [[Param]]s.
    *
    * @param estimator  Estimator to test
   * @param dataset  Dataset to pass to [[Estimator.fit()]]
-   * @param testParams  Set of [[Param]] values to set in estimator
+   * @param testEstimatorParams  Set of [[Param]] values to set in estimator
+   * @param testModelParams  Set of [[Param]] values to set in model
    * @param checkModelData  Method which takes the original and loaded [[Model]] and compares their
    *                        data.  This method does not need to check [[Param]] values.
    * @tparam E  Type of [[Estimator]]
@@ -99,24 +100,25 @@ trait DefaultReadWriteTest extends TempDirectory { self: Suite =>
     E <: Estimator[M] with MLWritable, M <: Model[M] with MLWritable](
       estimator: E,
       dataset: Dataset[_],
-      testParams: Map[String, Any],
+      testEstimatorParams: Map[String, Any],
+      testModelParams: Map[String, Any],
       checkModelData: (M, M) => Unit): Unit = {
     // Set some Params to make sure set Params are serialized.
-    testParams.foreach { case (p, v) =>
+    testEstimatorParams.foreach { case (p, v) =>
       estimator.set(estimator.getParam(p), v)
     }
     val model = estimator.fit(dataset)
 
     // Test Estimator save/load
     val estimator2 = testDefaultReadWrite(estimator)
-    testParams.foreach { case (p, v) =>
+    testEstimatorParams.foreach { case (p, v) =>
      val param = estimator.getParam(p)
       assert(estimator.get(param).get === estimator2.get(param).get)
     }
 
     // Test Model save/load
     val model2 = testDefaultReadWrite(model)
-    testParams.foreach { case (p, v) =>
+    testModelParams.foreach { case (p, v) =>
       val param = model.getParam(p)
       assert(model.get(param).get === model2.get(param).get)
     }
-- 
GitLab