diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index f8bcbeedfb042e5cd653551938e21b237917b31a..1308210417b1bf69c436c81c1b61e054cbf50107 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -1836,52 +1836,24 @@ class LogisticRegressionSuite .forall(x => x(0) >= x(1))) } - test("binary logistic regression with weighted data") { - val numClasses = 2 - val numPoints = 40 - val outlierData = MLTestingUtils.genClassificationInstancesWithWeightedOutliers(spark, - numClasses, numPoints) - val testData = Array.tabulate[LabeledPoint](numClasses) { i => - LabeledPoint(i.toDouble, Vectors.dense(i.toDouble)) - }.toSeq.toDF() - val lr = new LogisticRegression().setFamily("binomial").setWeightCol("weight") - val model = lr.fit(outlierData) - val results = model.transform(testData).select("label", "prediction").collect() - - // check that the predictions are the one to one mapping - results.foreach { case Row(label: Double, pred: Double) => - assert(label === pred) + test("logistic regression with sample weights") { + def modelEquals(m1: LogisticRegressionModel, m2: LogisticRegressionModel): Unit = { + assert(m1.coefficientMatrix ~== m2.coefficientMatrix absTol 0.05) + assert(m1.interceptVector ~== m2.interceptVector absTol 0.05) } - val (overSampledData, weightedData) = - MLTestingUtils.genEquivalentOversampledAndWeightedInstances(outlierData, "label", "features", - 42L) - val weightedModel = lr.fit(weightedData) - val overSampledModel = lr.setWeightCol("").fit(overSampledData) - assert(weightedModel.coefficientMatrix ~== overSampledModel.coefficientMatrix relTol 0.01) - } - - test("multinomial logistic regression with weighted data") { - val numClasses = 5 - val numPoints = 40 - val outlierData = MLTestingUtils.genClassificationInstancesWithWeightedOutliers(spark, - numClasses, numPoints) - val testData = Array.tabulate[LabeledPoint](numClasses) { i => - LabeledPoint(i.toDouble, Vectors.dense(i.toDouble)) - }.toSeq.toDF() - val mlr = new LogisticRegression().setFamily("multinomial").setWeightCol("weight") - val model = mlr.fit(outlierData) - val results = model.transform(testData).select("label", "prediction").collect() - - // check that the predictions are the one to one mapping - results.foreach { case Row(label: Double, pred: Double) => - assert(label === pred) + val testParams = Seq( + ("binomial", smallBinaryDataset, 2), + ("multinomial", smallMultinomialDataset, 3) + ) + testParams.foreach { case (family, dataset, numClasses) => + val estimator = new LogisticRegression().setFamily(family) + MLTestingUtils.testArbitrarilyScaledWeights[LogisticRegressionModel, LogisticRegression]( + dataset.as[LabeledPoint], estimator, modelEquals) + MLTestingUtils.testOutliersWithSmallWeights[LogisticRegressionModel, LogisticRegression]( + dataset.as[LabeledPoint], estimator, numClasses, modelEquals) + MLTestingUtils.testOversamplingVsWeighting[LogisticRegressionModel, LogisticRegression]( + dataset.as[LabeledPoint], estimator, modelEquals, seed) } - val (overSampledData, weightedData) = - MLTestingUtils.genEquivalentOversampledAndWeightedInstances(outlierData, "label", "features", - 42L) - val weightedModel = mlr.fit(weightedData) - val overSampledModel = mlr.setWeightCol("").fit(overSampledData) - assert(weightedModel.coefficientMatrix ~== overSampledModel.coefficientMatrix relTol 0.01) } test("set family") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala index e934e5ea42b16fe9ce39fa2039e00e1aca312cdc..2a69ef1c3ed01498283e763e1c9b72cfe652ca77 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala @@ -38,18 +38,22 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with Defa import testImplicits._ @transient var dataset: Dataset[_] = _ + @transient var bernoulliDataset: Dataset[_] = _ + + private val seed = 42 override def beforeAll(): Unit = { super.beforeAll() - val pi = Array(0.5, 0.1, 0.4).map(math.log) + val pi = Array(0.3, 0.3, 0.4).map(math.log) val theta = Array( - Array(0.70, 0.10, 0.10, 0.10), // label 0 - Array(0.10, 0.70, 0.10, 0.10), // label 1 - Array(0.10, 0.10, 0.70, 0.10) // label 2 + Array(0.30, 0.30, 0.30, 0.30), // label 0 + Array(0.30, 0.30, 0.30, 0.30), // label 1 + Array(0.40, 0.40, 0.40, 0.40) // label 2 ).map(_.map(math.log)) - dataset = generateNaiveBayesInput(pi, theta, 100, 42).toDF() + dataset = generateNaiveBayesInput(pi, theta, 100, seed).toDF() + bernoulliDataset = generateNaiveBayesInput(pi, theta, 100, seed, "bernoulli").toDF() } def validatePrediction(predictionAndLabels: DataFrame): Unit = { @@ -139,7 +143,7 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with Defa val theta = new DenseMatrix(3, 4, thetaArray.flatten, true) val testDataset = - generateNaiveBayesInput(piArray, thetaArray, nPoints, 42, "multinomial").toDF() + generateNaiveBayesInput(piArray, thetaArray, nPoints, seed, "multinomial").toDF() val nb = new NaiveBayes().setSmoothing(1.0).setModelType("multinomial") val model = nb.fit(testDataset) @@ -157,50 +161,27 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with Defa validateProbabilities(featureAndProbabilities, model, "multinomial") } - test("Naive Bayes Multinomial with weighted samples") { - val nPoints = 1000 - val piArray = Array(0.5, 0.1, 0.4).map(math.log) - val thetaArray = Array( - Array(0.70, 0.10, 0.10, 0.10), // label 0 - Array(0.10, 0.70, 0.10, 0.10), // label 1 - Array(0.10, 0.10, 0.70, 0.10) // label 2 - ).map(_.map(math.log)) - - val testData = generateNaiveBayesInput(piArray, thetaArray, nPoints, 42, "multinomial").toDF() - val (overSampledData, weightedData) = - MLTestingUtils.genEquivalentOversampledAndWeightedInstances(testData, - "label", "features", 42L) - val nb = new NaiveBayes().setModelType("multinomial") - val unweightedModel = nb.fit(weightedData) - val overSampledModel = nb.fit(overSampledData) - val weightedModel = nb.setWeightCol("weight").fit(weightedData) - assert(weightedModel.theta ~== overSampledModel.theta relTol 0.001) - assert(weightedModel.pi ~== overSampledModel.pi relTol 0.001) - assert(unweightedModel.theta !~= overSampledModel.theta relTol 0.001) - assert(unweightedModel.pi !~= overSampledModel.pi relTol 0.001) - } - - test("Naive Bayes Bernoulli with weighted samples") { - val nPoints = 10000 - val piArray = Array(0.5, 0.3, 0.2).map(math.log) - val thetaArray = Array( - Array(0.50, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.40), // label 0 - Array(0.02, 0.70, 0.10, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02), // label 1 - Array(0.02, 0.02, 0.60, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.30) // label 2 - ).map(_.map(math.log)) - - val testData = generateNaiveBayesInput(piArray, thetaArray, nPoints, 42, "bernoulli").toDF() - val (overSampledData, weightedData) = - MLTestingUtils.genEquivalentOversampledAndWeightedInstances(testData, - "label", "features", 42L) - val nb = new NaiveBayes().setModelType("bernoulli") - val unweightedModel = nb.fit(weightedData) - val overSampledModel = nb.fit(overSampledData) - val weightedModel = nb.setWeightCol("weight").fit(weightedData) - assert(weightedModel.theta ~== overSampledModel.theta relTol 0.001) - assert(weightedModel.pi ~== overSampledModel.pi relTol 0.001) - assert(unweightedModel.theta !~= overSampledModel.theta relTol 0.001) - assert(unweightedModel.pi !~= overSampledModel.pi relTol 0.001) + test("Naive Bayes with weighted samples") { + val numClasses = 3 + def modelEquals(m1: NaiveBayesModel, m2: NaiveBayesModel): Unit = { + assert(m1.pi ~== m2.pi relTol 0.01) + assert(m1.theta ~== m2.theta relTol 0.01) + } + val testParams = Seq( + ("bernoulli", bernoulliDataset), + ("multinomial", dataset) + ) + testParams.foreach { case (family, dataset) => + // NaiveBayes is sensitive to constant scaling of the weights unless smoothing is set to 0 + val estimatorNoSmoothing = new NaiveBayes().setSmoothing(0.0).setModelType(family) + val estimatorWithSmoothing = new NaiveBayes().setModelType(family) + MLTestingUtils.testArbitrarilyScaledWeights[NaiveBayesModel, NaiveBayes]( + dataset.as[LabeledPoint], estimatorNoSmoothing, modelEquals) + MLTestingUtils.testOutliersWithSmallWeights[NaiveBayesModel, NaiveBayes]( + dataset.as[LabeledPoint], estimatorWithSmoothing, numClasses, modelEquals) + MLTestingUtils.testOversamplingVsWeighting[NaiveBayesModel, NaiveBayes]( + dataset.as[LabeledPoint], estimatorWithSmoothing, modelEquals, seed) + } } test("Naive Bayes Bernoulli") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala index 0be82742a33beaa61430fc1edbbf9d816d5fd369..e05d0c941118ecc599127163e9f80a7f1c470bbb 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala @@ -36,6 +36,7 @@ class LinearRegressionSuite private val seed: Int = 42 @transient var datasetWithDenseFeature: DataFrame = _ + @transient var datasetWithStrongNoise: DataFrame = _ @transient var datasetWithDenseFeatureWithoutIntercept: DataFrame = _ @transient var datasetWithSparseFeature: DataFrame = _ @transient var datasetWithWeight: DataFrame = _ @@ -47,6 +48,11 @@ class LinearRegressionSuite datasetWithDenseFeature = sc.parallelize(LinearDataGenerator.generateLinearInput( intercept = 6.3, weights = Array(4.7, 7.2), xMean = Array(0.9, -1.3), xVariance = Array(0.7, 1.2), nPoints = 10000, seed, eps = 0.1), 2).map(_.asML).toDF() + + datasetWithStrongNoise = sc.parallelize(LinearDataGenerator.generateLinearInput( + intercept = 6.3, weights = Array(4.7, 7.2), xMean = Array(0.9, -1.3), + xVariance = Array(0.7, 1.2), nPoints = 100, seed, eps = 5.0), 2).map(_.asML).toDF() + /* datasetWithDenseFeatureWithoutIntercept is not needed for correctness testing but is useful for illustrating training model without intercept @@ -95,6 +101,7 @@ class LinearRegressionSuite Instance(17.0, 3.0, Vectors.dense(2.0, 11.0)), Instance(17.0, 4.0, Vectors.dense(3.0, 13.0)) ), 2).toDF() + datasetWithWeightZeroLabel = sc.parallelize(Seq( Instance(0.0, 1.0, Vectors.dense(0.0, 5.0).toSparse), Instance(0.0, 2.0, Vectors.dense(1.0, 7.0)), @@ -810,91 +817,34 @@ class LinearRegressionSuite } test("linear regression with weighted samples") { - Seq("auto", "l-bfgs", "normal").foreach { solver => - val (data, weightedData) = { - val activeData = LinearDataGenerator.generateLinearInput( - 6.3, Array(4.7, 7.2), Array(0.9, -1.3), Array(0.7, 1.2), 500, 1, 0.1).map(_.asML) - - val rnd = new Random(8392) - val signedData = activeData.map { case p: LabeledPoint => - (rnd.nextGaussian() > 0.0, p) - } - - val data1 = signedData.flatMap { - case (true, p) => Iterator(p, p) - case (false, p) => Iterator(p) - } - - val weightedSignedData = signedData.flatMap { - case (true, LabeledPoint(label, features)) => - Iterator( - Instance(label, weight = 1.2, features), - Instance(label, weight = 0.8, features) - ) - case (false, LabeledPoint(label, features)) => - Iterator( - Instance(label, weight = 0.3, features), - Instance(label, weight = 0.1, features), - Instance(label, weight = 0.6, features) - ) - } - - val noiseData = LinearDataGenerator.generateLinearInput( - 2, Array(1, 3), Array(0.9, -1.3), Array(0.7, 1.2), 500, 1, 0.1).map(_.asML) - val weightedNoiseData = noiseData.map { - case LabeledPoint(label, features) => Instance(label, weight = 0, features) - } - val data2 = weightedSignedData ++ weightedNoiseData - - (sc.parallelize(data1, 4).toDF(), sc.parallelize(data2, 4).toDF()) - } - - val trainer1a = (new LinearRegression).setFitIntercept(true) - .setElasticNetParam(0.0).setRegParam(0.21).setStandardization(true).setSolver(solver) - val trainer1b = (new LinearRegression).setFitIntercept(true).setWeightCol("weight") - .setElasticNetParam(0.0).setRegParam(0.21).setStandardization(true).setSolver(solver) - - // Normal optimizer is not supported with non-zero elasticnet parameter. - val model1a0 = trainer1a.fit(data) - val model1a1 = trainer1a.fit(weightedData) - val model1b = trainer1b.fit(weightedData) - - assert(model1a0.coefficients !~= model1a1.coefficients absTol 1E-3) - assert(model1a0.intercept !~= model1a1.intercept absTol 1E-3) - assert(model1a0.coefficients ~== model1b.coefficients absTol 1E-3) - assert(model1a0.intercept ~== model1b.intercept absTol 1E-3) - - val trainer2a = (new LinearRegression).setFitIntercept(true) - .setElasticNetParam(0.0).setRegParam(0.21).setStandardization(false).setSolver(solver) - val trainer2b = (new LinearRegression).setFitIntercept(true).setWeightCol("weight") - .setElasticNetParam(0.0).setRegParam(0.21).setStandardization(false).setSolver(solver) - val model2a0 = trainer2a.fit(data) - val model2a1 = trainer2a.fit(weightedData) - val model2b = trainer2b.fit(weightedData) - assert(model2a0.coefficients !~= model2a1.coefficients absTol 1E-3) - assert(model2a0.intercept !~= model2a1.intercept absTol 1E-3) - assert(model2a0.coefficients ~== model2b.coefficients absTol 1E-3) - assert(model2a0.intercept ~== model2b.intercept absTol 1E-3) - - val trainer3a = (new LinearRegression).setFitIntercept(false) - .setElasticNetParam(0.0).setRegParam(0.21).setStandardization(true).setSolver(solver) - val trainer3b = (new LinearRegression).setFitIntercept(false).setWeightCol("weight") - .setElasticNetParam(0.0).setRegParam(0.21).setStandardization(true).setSolver(solver) - val model3a0 = trainer3a.fit(data) - val model3a1 = trainer3a.fit(weightedData) - val model3b = trainer3b.fit(weightedData) - assert(model3a0.coefficients !~= model3a1.coefficients absTol 1E-3) - assert(model3a0.coefficients ~== model3b.coefficients absTol 1E-3) - - val trainer4a = (new LinearRegression).setFitIntercept(false) - .setElasticNetParam(0.0).setRegParam(0.21).setStandardization(false).setSolver(solver) - val trainer4b = (new LinearRegression).setFitIntercept(false).setWeightCol("weight") - .setElasticNetParam(0.0).setRegParam(0.21).setStandardization(false).setSolver(solver) - val model4a0 = trainer4a.fit(data) - val model4a1 = trainer4a.fit(weightedData) - val model4b = trainer4b.fit(weightedData) - assert(model4a0.coefficients !~= model4a1.coefficients absTol 1E-3) - assert(model4a0.coefficients ~== model4b.coefficients absTol 1E-3) + val sqlContext = spark.sqlContext + import sqlContext.implicits._ + val numClasses = 0 + def modelEquals(m1: LinearRegressionModel, m2: LinearRegressionModel): Unit = { + assert(m1.coefficients ~== m2.coefficients relTol 0.01) + assert(m1.intercept ~== m2.intercept relTol 0.01) + } + val testParams = Seq( + // (elasticNetParam, regParam, fitIntercept, standardization) + (0.0, 0.21, true, true), + (0.0, 0.21, true, false), + (0.0, 0.21, false, false), + (1.0, 0.21, true, true) + ) + + for (solver <- Seq("auto", "l-bfgs", "normal"); + (elasticNetParam, regParam, fitIntercept, standardization) <- testParams) { + val estimator = new LinearRegression() + .setFitIntercept(fitIntercept) + .setStandardization(standardization) + .setRegParam(regParam) + .setElasticNetParam(elasticNetParam) + MLTestingUtils.testArbitrarilyScaledWeights[LinearRegressionModel, LinearRegression]( + datasetWithStrongNoise.as[LabeledPoint], estimator, modelEquals) + MLTestingUtils.testOutliersWithSmallWeights[LinearRegressionModel, LinearRegression]( + datasetWithStrongNoise.as[LabeledPoint], estimator, numClasses, modelEquals) + MLTestingUtils.testOversamplingVsWeighting[LinearRegressionModel, LinearRegression]( + datasetWithStrongNoise.as[LabeledPoint], estimator, modelEquals, seed) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala index 472a5af06e7a2852aed90e4f57461a3dde8d413b..d219c428189240913a544768e54589bb0ea098e6 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala @@ -18,15 +18,15 @@ package org.apache.spark.ml.util import org.apache.spark.SparkFunSuite -import org.apache.spark.ml.{Estimator, Model} -import org.apache.spark.ml.attribute.NominalAttribute +import org.apache.spark.ml._ import org.apache.spark.ml.evaluation.Evaluator -import org.apache.spark.ml.feature.Instance +import org.apache.spark.ml.feature.{Instance, LabeledPoint} import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamMap +import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasLabelCol, HasWeightCol} import org.apache.spark.ml.recommendation.{ALS, ALSModel} import org.apache.spark.ml.tree.impl.TreeTests -import org.apache.spark.sql.{DataFrame, Row, SparkSession} +import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ @@ -182,46 +182,79 @@ object MLTestingUtils extends SparkFunSuite { .toMap } - def genClassificationInstancesWithWeightedOutliers( - spark: SparkSession, - numClasses: Int, - numInstances: Int): DataFrame = { - val data = Array.tabulate[Instance](numInstances) { i => - val feature = i % numClasses - if (i < numInstances / 3) { - // give large weights to minority of data with 1 to 1 mapping feature to label - Instance(feature, 1.0, Vectors.dense(feature)) - } else { - // give small weights to majority of data points with reverse mapping - Instance(numClasses - feature - 1, 0.01, Vectors.dense(feature)) - } - } - val labelMeta = - NominalAttribute.defaultAttr.withName("label").withNumValues(numClasses).toMetadata() - spark.createDataFrame(data).select(col("label").as("label", labelMeta), col("weight"), - col("features")) - } - + /** + * Given a DataFrame, generate two output DataFrames: one having the original rows oversampled + * an integer number of times, and one having the original rows but with a column of weights + * proportional to the number of oversampled instances in the oversampled DataFrames. + */ def genEquivalentOversampledAndWeightedInstances( - data: DataFrame, - labelCol: String, - featuresCol: String, - seed: Long): (DataFrame, DataFrame) = { + data: Dataset[LabeledPoint], + seed: Long): (Dataset[Instance], Dataset[Instance]) = { import data.sparkSession.implicits._ - val rng = scala.util.Random - rng.setSeed(seed) + val rng = new scala.util.Random(seed) val sample: () => Int = () => rng.nextInt(10) + 1 val sampleUDF = udf(sample) - val rawData = data.select(labelCol, featuresCol).withColumn("samples", sampleUDF()) - val overSampledData = rawData.rdd.flatMap { - case Row(label: Double, features: Vector, n: Int) => - Iterator.fill(n)(Instance(label, 1.0, features)) - }.toDF() + val rawData = data.select("label", "features").withColumn("samples", sampleUDF()) + val overSampledData = rawData.rdd.flatMap { case Row(label: Double, features: Vector, n: Int) => + Iterator.fill(n)(Instance(label, 1.0, features)) + }.toDS() rng.setSeed(seed) - val weightedData = rawData.rdd.map { - case Row(label: Double, features: Vector, n: Int) => - Instance(label, n.toDouble, features) - }.toDF() + val weightedData = rawData.rdd.map { case Row(label: Double, features: Vector, n: Int) => + Instance(label, n.toDouble, features) + }.toDS() (overSampledData, weightedData) } + + /** + * Helper function for testing sample weights. Tests that oversampling each point is equivalent + * to assigning a sample weight proportional to the number of samples for each point. + */ + def testOversamplingVsWeighting[M <: Model[M], E <: Estimator[M]]( + data: Dataset[LabeledPoint], + estimator: E with HasWeightCol, + modelEquals: (M, M) => Unit, + seed: Long): Unit = { + val (overSampledData, weightedData) = genEquivalentOversampledAndWeightedInstances( + data, seed) + val weightedModel = estimator.set(estimator.weightCol, "weight").fit(weightedData) + val overSampledModel = estimator.set(estimator.weightCol, "").fit(overSampledData) + modelEquals(weightedModel, overSampledModel) + } + + /** + * Helper function for testing sample weights. Tests that injecting a large number of outliers + * with very small sample weights does not affect fitting. The predictor should learn the true + * model despite the outliers. + */ + def testOutliersWithSmallWeights[M <: Model[M], E <: Estimator[M]]( + data: Dataset[LabeledPoint], + estimator: E with HasWeightCol, + numClasses: Int, + modelEquals: (M, M) => Unit): Unit = { + import data.sqlContext.implicits._ + val outlierDS = data.withColumn("weight", lit(1.0)).as[Instance].flatMap { + case Instance(l, w, f) => + val outlierLabel = if (numClasses == 0) -l else numClasses - l - 1 + List.fill(3)(Instance(outlierLabel, 0.0001, f)) ++ List(Instance(l, w, f)) + } + val trueModel = estimator.set(estimator.weightCol, "").fit(data) + val outlierModel = estimator.set(estimator.weightCol, "weight").fit(outlierDS) + modelEquals(trueModel, outlierModel) + } + + /** + * Helper function for testing sample weights. Tests that giving constant weights to each data + * point yields the same model, regardless of the magnitude of the weight. + */ + def testArbitrarilyScaledWeights[M <: Model[M], E <: Estimator[M]]( + data: Dataset[LabeledPoint], + estimator: E with HasWeightCol, + modelEquals: (M, M) => Unit): Unit = { + estimator.set(estimator.weightCol, "weight") + val models = Seq(0.001, 1.0, 1000.0).map { w => + val df = data.withColumn("weight", lit(w)) + estimator.fit(df) + } + models.sliding(2).foreach { case Seq(m1, m2) => modelEquals(m1, m2)} + } }