diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala index 04273a40d92a4ab72623432895d0dc5645db9a07..799e881fad74afde680991cf03fe702afc970d9c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala @@ -80,14 +80,24 @@ private[recommendation] trait ALSModelParams extends Params with HasPredictionCo /** * Attempts to safely cast a user/item id to an Int. Throws an exception if the value is - * out of integer range. + * out of integer range or contains a fractional part. */ - protected val checkedCast = udf { (n: Double) => - if (n > Int.MaxValue || n < Int.MinValue) { - throw new IllegalArgumentException(s"ALS only supports values in Integer range for columns " + - s"${$(userCol)} and ${$(itemCol)}. Value $n was out of Integer range.") - } else { - n.toInt + protected[recommendation] val checkedCast = udf { (n: Any) => + n match { + case v: Int => v // Avoid unnecessary casting + case v: Number => + val intV = v.intValue + // Checks if number within Int range and has no fractional part. + if (v.doubleValue == intV) { + intV + } else { + throw new IllegalArgumentException(s"ALS only supports values in Integer range " + + s"and without fractional part for columns ${$(userCol)} and ${$(itemCol)}. " + + s"Value $n was either out of Integer range or contained a fractional part that " + + s"could not be converted.") + } + case _ => throw new IllegalArgumentException(s"ALS only supports values in Integer range " + + s"for columns ${$(userCol)} and ${$(itemCol)}. Value $n was not numeric.") } } @@ -288,9 +298,9 @@ class ALSModel private[ml] ( } val predictions = dataset .join(userFactors, - checkedCast(dataset($(userCol)).cast(DoubleType)) === userFactors("id"), "left") + checkedCast(dataset($(userCol))) === userFactors("id"), "left") .join(itemFactors, - checkedCast(dataset($(itemCol)).cast(DoubleType)) === itemFactors("id"), "left") + checkedCast(dataset($(itemCol))) === itemFactors("id"), "left") .select(dataset("*"), predict(userFactors("features"), itemFactors("features")).as($(predictionCol))) getColdStartStrategy match { @@ -491,8 +501,7 @@ class ALS(@Since("1.4.0") override val uid: String) extends Estimator[ALSModel] val r = if ($(ratingCol) != "") col($(ratingCol)).cast(FloatType) else lit(1.0f) val ratings = dataset - .select(checkedCast(col($(userCol)).cast(DoubleType)), - checkedCast(col($(itemCol)).cast(DoubleType)), r) + .select(checkedCast(col($(userCol))), checkedCast(col($(itemCol))), r) .rdd .map { row => Rating(row.getInt(0), row.getInt(1), row.getFloat(2)) diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala index c9e7b505b2bd2a69f4efee3f67e65c608d452d57..c8228dd004374a7536abab51708b1022d1da09d3 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala @@ -40,7 +40,8 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.rdd.RDD import org.apache.spark.scheduler.{SparkListener, SparkListenerStageCompleted} import org.apache.spark.sql.{DataFrame, Row, SparkSession} -import org.apache.spark.sql.types.{FloatType, IntegerType} +import org.apache.spark.sql.functions.lit +import org.apache.spark.sql.types._ import org.apache.spark.storage.StorageLevel import org.apache.spark.util.Utils @@ -205,6 +206,70 @@ class ALSSuite assert(decompressed.toSet === expected) } + test("CheckedCast") { + val checkedCast = new ALS().checkedCast + val df = spark.range(1) + + withClue("Valid Integer Ids") { + df.select(checkedCast(lit(123))).collect() + } + + withClue("Valid Long Ids") { + df.select(checkedCast(lit(1231L))).collect() + } + + withClue("Valid Decimal Ids") { + df.select(checkedCast(lit(123).cast(DecimalType(15, 2)))).collect() + } + + withClue("Valid Double Ids") { + df.select(checkedCast(lit(123.0))).collect() + } + + val msg = "either out of Integer range or contained a fractional part" + withClue("Invalid Long: out of range") { + val e: SparkException = intercept[SparkException] { + df.select(checkedCast(lit(1231000000000L))).collect() + } + assert(e.getMessage.contains(msg)) + } + + withClue("Invalid Decimal: out of range") { + val e: SparkException = intercept[SparkException] { + df.select(checkedCast(lit(1231000000000.0).cast(DecimalType(15, 2)))).collect() + } + assert(e.getMessage.contains(msg)) + } + + withClue("Invalid Decimal: fractional part") { + val e: SparkException = intercept[SparkException] { + df.select(checkedCast(lit(123.1).cast(DecimalType(15, 2)))).collect() + } + assert(e.getMessage.contains(msg)) + } + + withClue("Invalid Double: out of range") { + val e: SparkException = intercept[SparkException] { + df.select(checkedCast(lit(1231000000000.0))).collect() + } + assert(e.getMessage.contains(msg)) + } + + withClue("Invalid Double: fractional part") { + val e: SparkException = intercept[SparkException] { + df.select(checkedCast(lit(123.1))).collect() + } + assert(e.getMessage.contains(msg)) + } + + withClue("Invalid Type") { + val e: SparkException = intercept[SparkException] { + df.select(checkedCast(lit("123.1"))).collect() + } + assert(e.getMessage.contains("was not numeric")) + } + } + /** * Generates an explicit feedback dataset for testing ALS. * @param numUsers number of users @@ -510,34 +575,35 @@ class ALSSuite (0, big, small, 0, big, small, 2.0), (1, 1L, 1d, 0, 0L, 0d, 5.0) ).toDF("user", "user_big", "user_small", "item", "item_big", "item_small", "rating") + val msg = "either out of Integer range or contained a fractional part" withClue("fit should fail when ids exceed integer range. ") { assert(intercept[SparkException] { als.fit(df.select(df("user_big").as("user"), df("item"), df("rating"))) - }.getCause.getMessage.contains("was out of Integer range")) + }.getCause.getMessage.contains(msg)) assert(intercept[SparkException] { als.fit(df.select(df("user_small").as("user"), df("item"), df("rating"))) - }.getCause.getMessage.contains("was out of Integer range")) + }.getCause.getMessage.contains(msg)) assert(intercept[SparkException] { als.fit(df.select(df("item_big").as("item"), df("user"), df("rating"))) - }.getCause.getMessage.contains("was out of Integer range")) + }.getCause.getMessage.contains(msg)) assert(intercept[SparkException] { als.fit(df.select(df("item_small").as("item"), df("user"), df("rating"))) - }.getCause.getMessage.contains("was out of Integer range")) + }.getCause.getMessage.contains(msg)) } withClue("transform should fail when ids exceed integer range. ") { val model = als.fit(df) assert(intercept[SparkException] { model.transform(df.select(df("user_big").as("user"), df("item"))).first - }.getMessage.contains("was out of Integer range")) + }.getMessage.contains(msg)) assert(intercept[SparkException] { model.transform(df.select(df("user_small").as("user"), df("item"))).first - }.getMessage.contains("was out of Integer range")) + }.getMessage.contains(msg)) assert(intercept[SparkException] { model.transform(df.select(df("item_big").as("item"), df("user"))).first - }.getMessage.contains("was out of Integer range")) + }.getMessage.contains(msg)) assert(intercept[SparkException] { model.transform(df.select(df("item_small").as("item"), df("user"))).first - }.getMessage.contains("was out of Integer range")) + }.getMessage.contains(msg)) } }