Skip to content
Snippets Groups Projects
Commit 625cfe09 authored by Vasilis Vryniotis's avatar Vasilis Vryniotis Committed by Nick Pentreath
Browse files

[SPARK-19733][ML] Removed unnecessary castings and refactored checked casts in ALS.

## What changes were proposed in this pull request?

The original ALS was performing unnecessary casting to the user and item ids because the protected checkedCast() method required a double. I removed the castings and refactored the method to receive Any and efficiently handle all permitted numeric values.

## How was this patch tested?

I tested it by running the unit-tests and by manually validating the result of checkedCast for various legal and illegal values.

Author: Vasilis Vryniotis <bbriniotis@datumbox.com>

Closes #17059 from datumbox/als_casting_fix.
parent 8d6ef895
No related branches found
No related tags found
No related merge requests found
...@@ -80,14 +80,24 @@ private[recommendation] trait ALSModelParams extends Params with HasPredictionCo ...@@ -80,14 +80,24 @@ private[recommendation] trait ALSModelParams extends Params with HasPredictionCo
/**
 * Attempts to safely cast a user/item id to an Int. Throws an exception if the value is
 * out of integer range or contains a fractional part.
 */
protected[recommendation] val checkedCast = udf { (n: Any) =>
  n match {
    // Ids that are already Ints need no conversion or validation.
    case i: Int => i
    case num: Number =>
      val truncated = num.intValue
      // The id is a whole number within Int range iff truncating it to Int
      // round-trips through Double unchanged (Doubles represent every
      // Int-range integer exactly, so this comparison is lossless here).
      if (num.doubleValue == truncated) {
        truncated
      } else {
        throw new IllegalArgumentException(s"ALS only supports values in Integer range " +
          s"and without fractional part for columns ${$(userCol)} and ${$(itemCol)}. " +
          s"Value $n was either out of Integer range or contained a fractional part that " +
          s"could not be converted.")
      }
    // Non-numeric ids (e.g. strings) are rejected outright.
    case _ => throw new IllegalArgumentException(s"ALS only supports values in Integer range " +
      s"for columns ${$(userCol)} and ${$(itemCol)}. Value $n was not numeric.")
  }
}
...@@ -288,9 +298,9 @@ class ALSModel private[ml] ( ...@@ -288,9 +298,9 @@ class ALSModel private[ml] (
} }
val predictions = dataset val predictions = dataset
.join(userFactors, .join(userFactors,
checkedCast(dataset($(userCol)).cast(DoubleType)) === userFactors("id"), "left") checkedCast(dataset($(userCol))) === userFactors("id"), "left")
.join(itemFactors, .join(itemFactors,
checkedCast(dataset($(itemCol)).cast(DoubleType)) === itemFactors("id"), "left") checkedCast(dataset($(itemCol))) === itemFactors("id"), "left")
.select(dataset("*"), .select(dataset("*"),
predict(userFactors("features"), itemFactors("features")).as($(predictionCol))) predict(userFactors("features"), itemFactors("features")).as($(predictionCol)))
getColdStartStrategy match { getColdStartStrategy match {
...@@ -491,8 +501,7 @@ class ALS(@Since("1.4.0") override val uid: String) extends Estimator[ALSModel] ...@@ -491,8 +501,7 @@ class ALS(@Since("1.4.0") override val uid: String) extends Estimator[ALSModel]
val r = if ($(ratingCol) != "") col($(ratingCol)).cast(FloatType) else lit(1.0f) val r = if ($(ratingCol) != "") col($(ratingCol)).cast(FloatType) else lit(1.0f)
val ratings = dataset val ratings = dataset
.select(checkedCast(col($(userCol)).cast(DoubleType)), .select(checkedCast(col($(userCol))), checkedCast(col($(itemCol))), r)
checkedCast(col($(itemCol)).cast(DoubleType)), r)
.rdd .rdd
.map { row => .map { row =>
Rating(row.getInt(0), row.getInt(1), row.getFloat(2)) Rating(row.getInt(0), row.getInt(1), row.getFloat(2))
......
...@@ -40,7 +40,8 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext ...@@ -40,7 +40,8 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.rdd.RDD import org.apache.spark.rdd.RDD
import org.apache.spark.scheduler.{SparkListener, SparkListenerStageCompleted} import org.apache.spark.scheduler.{SparkListener, SparkListenerStageCompleted}
import org.apache.spark.sql.{DataFrame, Row, SparkSession} import org.apache.spark.sql.{DataFrame, Row, SparkSession}
import org.apache.spark.sql.types.{FloatType, IntegerType} import org.apache.spark.sql.functions.lit
import org.apache.spark.sql.types._
import org.apache.spark.storage.StorageLevel import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.Utils import org.apache.spark.util.Utils
...@@ -205,6 +206,70 @@ class ALSSuite ...@@ -205,6 +206,70 @@ class ALSSuite
assert(decompressed.toSet === expected) assert(decompressed.toSet === expected)
} }
test("CheckedCast") {
  val checkedCast = new ALS().checkedCast
  val df = spark.range(1)

  // Helper: evaluating the projection must succeed without throwing.
  def expectValid(clue: String)(frame: => DataFrame): Unit =
    withClue(clue) { frame.collect() }

  // Helper: evaluating the projection must fail, and the failure message
  // must contain `fragment`.
  def expectFailure(clue: String, fragment: String)(frame: => DataFrame): Unit =
    withClue(clue) {
      val e: SparkException = intercept[SparkException] { frame.collect() }
      assert(e.getMessage.contains(fragment))
    }

  // Legal id types: Int, Long, Decimal and Double holding whole numbers in Int range.
  expectValid("Valid Integer Ids") { df.select(checkedCast(lit(123))) }
  expectValid("Valid Long Ids") { df.select(checkedCast(lit(1231L))) }
  expectValid("Valid Decimal Ids") { df.select(checkedCast(lit(123).cast(DecimalType(15, 2)))) }
  expectValid("Valid Double Ids") { df.select(checkedCast(lit(123.0))) }

  val msg = "either out of Integer range or contained a fractional part"
  // Numeric ids that overflow Int or carry a fractional part must be rejected.
  expectFailure("Invalid Long: out of range", msg) {
    df.select(checkedCast(lit(1231000000000L)))
  }
  expectFailure("Invalid Decimal: out of range", msg) {
    df.select(checkedCast(lit(1231000000000.0).cast(DecimalType(15, 2))))
  }
  expectFailure("Invalid Decimal: fractional part", msg) {
    df.select(checkedCast(lit(123.1).cast(DecimalType(15, 2))))
  }
  expectFailure("Invalid Double: out of range", msg) {
    df.select(checkedCast(lit(1231000000000.0)))
  }
  expectFailure("Invalid Double: fractional part", msg) {
    df.select(checkedCast(lit(123.1)))
  }
  // Non-numeric ids fail with a distinct message.
  expectFailure("Invalid Type", "was not numeric") {
    df.select(checkedCast(lit("123.1")))
  }
}
/** /**
* Generates an explicit feedback dataset for testing ALS. * Generates an explicit feedback dataset for testing ALS.
* @param numUsers number of users * @param numUsers number of users
...@@ -510,34 +575,35 @@ class ALSSuite ...@@ -510,34 +575,35 @@ class ALSSuite
(0, big, small, 0, big, small, 2.0), (0, big, small, 0, big, small, 2.0),
(1, 1L, 1d, 0, 0L, 0d, 5.0) (1, 1L, 1d, 0, 0L, 0d, 5.0)
).toDF("user", "user_big", "user_small", "item", "item_big", "item_small", "rating") ).toDF("user", "user_big", "user_small", "item", "item_big", "item_small", "rating")
val msg = "either out of Integer range or contained a fractional part"
withClue("fit should fail when ids exceed integer range. ") { withClue("fit should fail when ids exceed integer range. ") {
assert(intercept[SparkException] { assert(intercept[SparkException] {
als.fit(df.select(df("user_big").as("user"), df("item"), df("rating"))) als.fit(df.select(df("user_big").as("user"), df("item"), df("rating")))
}.getCause.getMessage.contains("was out of Integer range")) }.getCause.getMessage.contains(msg))
assert(intercept[SparkException] { assert(intercept[SparkException] {
als.fit(df.select(df("user_small").as("user"), df("item"), df("rating"))) als.fit(df.select(df("user_small").as("user"), df("item"), df("rating")))
}.getCause.getMessage.contains("was out of Integer range")) }.getCause.getMessage.contains(msg))
assert(intercept[SparkException] { assert(intercept[SparkException] {
als.fit(df.select(df("item_big").as("item"), df("user"), df("rating"))) als.fit(df.select(df("item_big").as("item"), df("user"), df("rating")))
}.getCause.getMessage.contains("was out of Integer range")) }.getCause.getMessage.contains(msg))
assert(intercept[SparkException] { assert(intercept[SparkException] {
als.fit(df.select(df("item_small").as("item"), df("user"), df("rating"))) als.fit(df.select(df("item_small").as("item"), df("user"), df("rating")))
}.getCause.getMessage.contains("was out of Integer range")) }.getCause.getMessage.contains(msg))
} }
withClue("transform should fail when ids exceed integer range. ") { withClue("transform should fail when ids exceed integer range. ") {
val model = als.fit(df) val model = als.fit(df)
assert(intercept[SparkException] { assert(intercept[SparkException] {
model.transform(df.select(df("user_big").as("user"), df("item"))).first model.transform(df.select(df("user_big").as("user"), df("item"))).first
}.getMessage.contains("was out of Integer range")) }.getMessage.contains(msg))
assert(intercept[SparkException] { assert(intercept[SparkException] {
model.transform(df.select(df("user_small").as("user"), df("item"))).first model.transform(df.select(df("user_small").as("user"), df("item"))).first
}.getMessage.contains("was out of Integer range")) }.getMessage.contains(msg))
assert(intercept[SparkException] { assert(intercept[SparkException] {
model.transform(df.select(df("item_big").as("item"), df("user"))).first model.transform(df.select(df("item_big").as("item"), df("user"))).first
}.getMessage.contains("was out of Integer range")) }.getMessage.contains(msg))
assert(intercept[SparkException] { assert(intercept[SparkException] {
model.transform(df.select(df("item_small").as("item"), df("user"))).first model.transform(df.select(df("item_small").as("item"), df("user"))).first
}.getMessage.contains("was out of Integer range")) }.getMessage.contains(msg))
} }
} }
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment