From 3ca9a2c56513444d7b233088b020d2d43fa6b77a Mon Sep 17 00:00:00 2001 From: Gabor Somogyi <gabor.g.somogyi@gmail.com> Date: Sun, 25 Feb 2018 09:29:59 -0600 Subject: [PATCH] =?UTF-8?q?[SPARK-22886][ML][TESTS]=20ML=20test=20for=20st?= =?UTF-8?q?ructured=20streaming:=20ml.recomme=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? Converting spark.ml.recommendation tests to also check code with structured streaming, using the ML testing infrastructure implemented in SPARK-22882. ## How was this patch tested? Automated: Pass the Jenkins. Author: Gabor Somogyi <gabor.g.somogyi@gmail.com> Closes #20362 from gaborgsomogyi/SPARK-22886. --- .../spark/ml/recommendation/ALSSuite.scala | 213 ++++++++++++------ .../apache/spark/ml/util/MLTestingUtils.scala | 44 ---- 2 files changed, 143 insertions(+), 114 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala index addcd21d50..e3dfe2faf5 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala @@ -22,8 +22,7 @@ import java.util.Random import scala.collection.JavaConverters._ import scala.collection.mutable -import scala.collection.mutable.ArrayBuffer -import scala.collection.mutable.WrappedArray +import scala.collection.mutable.{ArrayBuffer, WrappedArray} import scala.language.existentials import com.github.fommil.netlib.BLAS.{getInstance => blas} @@ -35,21 +34,20 @@ import org.apache.spark._ import org.apache.spark.internal.Logging import org.apache.spark.ml.linalg.Vectors import org.apache.spark.ml.recommendation.ALS._ -import org.apache.spark.ml.recommendation.ALS.Rating -import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ -import org.apache.spark.mllib.recommendation.MatrixFactorizationModelSuite import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.rdd.RDD import org.apache.spark.scheduler.{SparkListener, SparkListenerStageCompleted} -import org.apache.spark.sql.{DataFrame, Row, SparkSession} -import org.apache.spark.sql.functions.lit +import org.apache.spark.sql.{DataFrame, Encoder, Row, SparkSession} +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.functions.{col, lit} +import org.apache.spark.sql.streaming.StreamingQueryException import org.apache.spark.sql.types._ import org.apache.spark.storage.StorageLevel import org.apache.spark.util.Utils -class ALSSuite - extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest with Logging { +class ALSSuite extends MLTest with DefaultReadWriteTest with Logging { override def beforeAll(): Unit = { super.beforeAll() @@ -413,34 +411,36 @@ class ALSSuite .setSeed(0) val alpha = als.getAlpha val model = als.fit(training.toDF()) - val predictions = model.transform(test.toDF()).select("rating", "prediction").rdd.map { - case Row(rating: Float, prediction: Float) => - (rating.toDouble, prediction.toDouble) + testTransformerByGlobalCheckFunc[Rating[Int]](test.toDF(), model, "rating", "prediction") { + case rows: Seq[Row] => + val predictions = rows.map(row => (row.getFloat(0).toDouble, row.getFloat(1).toDouble)) + + val rmse = + if (implicitPrefs) { + // TODO: Use a better (rank-based?) evaluation metric for implicit feedback. + // We limit the ratings and the predictions to interval [0, 1] and compute the + // weighted RMSE with the confidence scores as weights. + val (totalWeight, weightedSumSq) = predictions.map { case (rating, prediction) => + val confidence = 1.0 + alpha * math.abs(rating) + val rating01 = math.max(math.min(rating, 1.0), 0.0) + val prediction01 = math.max(math.min(prediction, 1.0), 0.0) + val err = prediction01 - rating01 + (confidence, confidence * err * err) + }.reduce[(Double, Double)] { case ((c0, e0), (c1, e1)) => + (c0 + c1, e0 + e1) + } + math.sqrt(weightedSumSq / totalWeight) + } else { + val errorSquares = predictions.map { case (rating, prediction) => + val err = rating - prediction + err * err + } + val mse = errorSquares.sum / errorSquares.length + math.sqrt(mse) + } + logInfo(s"Test RMSE is $rmse.") + assert(rmse < targetRMSE) } - val rmse = - if (implicitPrefs) { - // TODO: Use a better (rank-based?) evaluation metric for implicit feedback. - // We limit the ratings and the predictions to interval [0, 1] and compute the weighted RMSE - // with the confidence scores as weights. - val (totalWeight, weightedSumSq) = predictions.map { case (rating, prediction) => - val confidence = 1.0 + alpha * math.abs(rating) - val rating01 = math.max(math.min(rating, 1.0), 0.0) - val prediction01 = math.max(math.min(prediction, 1.0), 0.0) - val err = prediction01 - rating01 - (confidence, confidence * err * err) - }.reduce { case ((c0, e0), (c1, e1)) => - (c0 + c1, e0 + e1) - } - math.sqrt(weightedSumSq / totalWeight) - } else { - val mse = predictions.map { case (rating, prediction) => - val err = rating - prediction - err * err - }.mean() - math.sqrt(mse) - } - logInfo(s"Test RMSE is $rmse.") - assert(rmse < targetRMSE) MLTestingUtils.checkCopyAndUids(als, model) } @@ -586,6 +586,68 @@ class ALSSuite allModelParamSettings, checkModelData) } + private def checkNumericTypesALS( + estimator: ALS, + spark: SparkSession, + column: String, + baseType: NumericType) + (check: (ALSModel, ALSModel) => Unit) + (check2: (ALSModel, ALSModel, DataFrame, Encoder[_]) => Unit): Unit = { + val dfs = genRatingsDFWithNumericCols(spark, column) + val df = dfs.find { + case (numericTypeWithEncoder, _) => numericTypeWithEncoder.numericType == baseType + } match { + case Some((_, df)) => df + } + val expected = estimator.fit(df) + val actuals = dfs.filter(_ != baseType).map(t => (t, estimator.fit(t._2))) + actuals.foreach { case (_, actual) => check(expected, actual) } + actuals.foreach { case (t, actual) => check2(expected, actual, t._2, t._1.encoder) } + + val baseDF = dfs.find(_._1.numericType == baseType).get._2 + val others = baseDF.columns.toSeq.diff(Seq(column)).map(col) + val cols = Seq(col(column).cast(StringType)) ++ others + val strDF = baseDF.select(cols: _*) + val thrown = intercept[IllegalArgumentException] { + estimator.fit(strDF) + } + assert(thrown.getMessage.contains( + s"$column must be of type NumericType but was actually of type StringType")) + } + + private class NumericTypeWithEncoder[A](val numericType: NumericType) + (implicit val encoder: Encoder[(A, Int, Double)]) + + private def genRatingsDFWithNumericCols( + spark: SparkSession, + column: String) = { + + import testImplicits._ + + val df = spark.createDataFrame(Seq( + (0, 10, 1.0), + (1, 20, 2.0), + (2, 30, 3.0), + (3, 40, 4.0), + (4, 50, 5.0) + )).toDF("user", "item", "rating") + + val others = df.columns.toSeq.diff(Seq(column)).map(col) + val types = + Seq(new NumericTypeWithEncoder[Short](ShortType), + new NumericTypeWithEncoder[Long](LongType), + new NumericTypeWithEncoder[Int](IntegerType), + new NumericTypeWithEncoder[Float](FloatType), + new NumericTypeWithEncoder[Byte](ByteType), + new NumericTypeWithEncoder[Double](DoubleType), + new NumericTypeWithEncoder[Decimal](DecimalType(10, 0))(ExpressionEncoder()) + ) + types.map { t => + val cols = Seq(col(column).cast(t.numericType)) ++ others + t -> df.select(cols: _*) + } + } + test("input type validation") { val spark = this.spark import spark.implicits._ @@ -595,12 +657,16 @@ class ALSSuite val als = new ALS().setMaxIter(1).setRank(1) Seq(("user", IntegerType), ("item", IntegerType), ("rating", FloatType)).foreach { case (colName, sqlType) => - MLTestingUtils.checkNumericTypesALS(als, spark, colName, sqlType) { + checkNumericTypesALS(als, spark, colName, sqlType) { (ex, act) => - ex.userFactors.first().getSeq[Float](1) === act.userFactors.first.getSeq[Float](1) - } { (ex, act, _) => - ex.transform(_: DataFrame).select("prediction").first.getDouble(0) ~== - act.transform(_: DataFrame).select("prediction").first.getDouble(0) absTol 1e-6 + ex.userFactors.first().getSeq[Float](1) === act.userFactors.first().getSeq[Float](1) + } { (ex, act, df, enc) => + val expected = ex.transform(df).selectExpr("prediction") + .first().getFloat(0) + testTransformerByGlobalCheckFunc(df, act, "prediction") { + case rows: Seq[Row] => + expected ~== rows.head.getFloat(0) absTol 1e-6 + }(enc) } } // check user/item ids falling outside of Int range @@ -628,18 +694,22 @@ class ALSSuite } withClue("transform should fail when ids exceed integer range. ") { val model = als.fit(df) - assert(intercept[SparkException] { - model.transform(df.select(df("user_big").as("user"), df("item"))).first - }.getMessage.contains(msg)) - assert(intercept[SparkException] { - model.transform(df.select(df("user_small").as("user"), df("item"))).first - }.getMessage.contains(msg)) - assert(intercept[SparkException] { - model.transform(df.select(df("item_big").as("item"), df("user"))).first - }.getMessage.contains(msg)) - assert(intercept[SparkException] { - model.transform(df.select(df("item_small").as("item"), df("user"))).first - }.getMessage.contains(msg)) + def testTransformIdExceedsIntRange[A : Encoder](dataFrame: DataFrame): Unit = { + assert(intercept[SparkException] { + model.transform(dataFrame).first + }.getMessage.contains(msg)) + assert(intercept[StreamingQueryException] { + testTransformer[A](dataFrame, model, "prediction") { _ => } + }.getMessage.contains(msg)) + } + testTransformIdExceedsIntRange[(Long, Int)](df.select(df("user_big").as("user"), + df("item"))) + testTransformIdExceedsIntRange[(Double, Int)](df.select(df("user_small").as("user"), + df("item"))) + testTransformIdExceedsIntRange[(Long, Int)](df.select(df("item_big").as("item"), + df("user"))) + testTransformIdExceedsIntRange[(Double, Int)](df.select(df("item_small").as("item"), + df("user"))) } } @@ -662,28 +732,31 @@ class ALSSuite val knownItem = data.select(max("item")).as[Int].first() val unknownItem = knownItem + 20 val test = Seq( - (unknownUser, unknownItem), - (knownUser, unknownItem), - (unknownUser, knownItem), - (knownUser, knownItem) - ).toDF("user", "item") + (unknownUser, unknownItem, true), + (knownUser, unknownItem, true), + (unknownUser, knownItem, true), + (knownUser, knownItem, false) + ).toDF("user", "item", "expectedIsNaN") val als = new ALS().setMaxIter(1).setRank(1) // default is 'nan' val defaultModel = als.fit(data) - val defaultPredictions = defaultModel.transform(test).select("prediction").as[Float].collect() - assert(defaultPredictions.length == 4) - assert(defaultPredictions.slice(0, 3).forall(_.isNaN)) - assert(!defaultPredictions.last.isNaN) + testTransformer[(Int, Int, Boolean)](test, defaultModel, "expectedIsNaN", "prediction") { + case Row(expectedIsNaN: Boolean, prediction: Float) => + assert(prediction.isNaN === expectedIsNaN) + } // check 'drop' strategy should filter out rows with unknown users/items - val dropPredictions = defaultModel - .setColdStartStrategy("drop") - .transform(test) - .select("prediction").as[Float].collect() - assert(dropPredictions.length == 1) - assert(!dropPredictions.head.isNaN) - assert(dropPredictions.head ~== defaultPredictions.last relTol 1e-14) + val defaultPrediction = defaultModel.transform(test).select("prediction") + .as[Float].filter(!_.isNaN).first() + testTransformerByGlobalCheckFunc[(Int, Int, Boolean)](test, + defaultModel.setColdStartStrategy("drop"), "prediction") { + case rows: Seq[Row] => + val dropPredictions = rows.map(_.getFloat(0)) + assert(dropPredictions.length == 1) + assert(!dropPredictions.head.isNaN) + assert(dropPredictions.head ~== defaultPrediction relTol 1e-14) + } } test("case insensitive cold start param value") { @@ -693,7 +766,7 @@ class ALSSuite val data = ratings.toDF val model = new ALS().fit(data) Seq("nan", "NaN", "Nan", "drop", "DROP", "Drop").foreach { s => - model.setColdStartStrategy(s).transform(data) + testTransformer[Rating[Int]](data, model.setColdStartStrategy(s), "prediction") { _ => } } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala index aef81c8c17..c328d81b4b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala @@ -91,30 +91,6 @@ object MLTestingUtils extends SparkFunSuite { } } - def checkNumericTypesALS( - estimator: ALS, - spark: SparkSession, - column: String, - baseType: NumericType) - (check: (ALSModel, ALSModel) => Unit) - (check2: (ALSModel, ALSModel, DataFrame) => Unit): Unit = { - val dfs = genRatingsDFWithNumericCols(spark, column) - val expected = estimator.fit(dfs(baseType)) - val actuals = dfs.keys.filter(_ != baseType).map(t => (t, estimator.fit(dfs(t)))) - actuals.foreach { case (_, actual) => check(expected, actual) } - actuals.foreach { case (t, actual) => check2(expected, actual, dfs(t)) } - - val baseDF = dfs(baseType) - val others = baseDF.columns.toSeq.diff(Seq(column)).map(col) - val cols = Seq(col(column).cast(StringType)) ++ others - val strDF = baseDF.select(cols: _*) - val thrown = intercept[IllegalArgumentException] { - estimator.fit(strDF) - } - assert(thrown.getMessage.contains( - s"$column must be of type NumericType but was actually of type StringType")) - } - def checkNumericTypes[T <: Evaluator](evaluator: T, spark: SparkSession): Unit = { val dfs = genEvaluatorDFWithNumericLabelCol(spark, "label", "prediction") val expected = evaluator.evaluate(dfs(DoubleType)) @@ -176,26 +152,6 @@ object MLTestingUtils extends SparkFunSuite { }.toMap } - def genRatingsDFWithNumericCols( - spark: SparkSession, - column: String): Map[NumericType, DataFrame] = { - val df = spark.createDataFrame(Seq( - (0, 10, 1.0), - (1, 20, 2.0), - (2, 30, 3.0), - (3, 40, 4.0), - (4, 50, 5.0) - )).toDF("user", "item", "rating") - - val others = df.columns.toSeq.diff(Seq(column)).map(col) - val types: Seq[NumericType] = - Seq(ShortType, LongType, IntegerType, FloatType, ByteType, DoubleType, DecimalType(10, 0)) - types.map { t => - val cols = Seq(col(column).cast(t)) ++ others - t -> df.select(cols: _*) - }.toMap - } - def genEvaluatorDFWithNumericLabelCol( spark: SparkSession, labelColName: String = "label", -- GitLab