diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
index 97c86552986096f61c2549b15fa2d190e23139c4..af007625d1827cc5c3c9da517de643a2c51d69c1 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
@@ -90,6 +90,27 @@ private[recommendation] trait ALSModelParams extends Params with HasPredictionCo
       n.toInt
     }
   }
+
+  /**
+   * Param for strategy for dealing with unknown or new users/items at prediction time.
+   * This may be useful in cross-validation or production scenarios, for handling user/item ids
+   * the model has not seen in the training data.
+   * Supported values:
+   * - "nan": predicted value for unknown ids will be NaN.
+   * - "drop": rows in the input DataFrame containing unknown ids will be dropped from
+   *           the output DataFrame containing predictions.
+   * Default: "nan".
+   * @group expertParam
+   */
+  val coldStartStrategy = new Param[String](this, "coldStartStrategy",
+    "strategy for dealing with unknown or new users/items at prediction time. This may be " +
+    "useful in cross-validation or production scenarios, for handling user/item ids the model " +
+    "has not seen in the training data. Supported values: " +
+    s"${ALSModel.supportedColdStartStrategies.mkString(",")}.",
+    (s: String) => ALSModel.supportedColdStartStrategies.contains(s.toLowerCase))
+
+  /** @group expertGetParam */
+  def getColdStartStrategy: String = $(coldStartStrategy).toLowerCase
 }
 
 /**
@@ -203,7 +224,8 @@ private[recommendation] trait ALSParams extends ALSModelParams with HasMaxIter w
   setDefault(rank -> 10, maxIter -> 10, regParam -> 0.1, numUserBlocks -> 10,
     numItemBlocks -> 10, implicitPrefs -> false, alpha -> 1.0, userCol -> "user",
     itemCol -> "item", ratingCol -> "rating", nonnegative -> false, checkpointInterval -> 10,
-    intermediateStorageLevel -> "MEMORY_AND_DISK", finalStorageLevel -> "MEMORY_AND_DISK")
+    intermediateStorageLevel -> "MEMORY_AND_DISK", finalStorageLevel -> "MEMORY_AND_DISK",
+    coldStartStrategy -> "nan")
 
   /**
    * Validates and transforms the input schema.
@@ -248,6 +270,10 @@ class ALSModel private[ml] (
   @Since("1.3.0")
   def setPredictionCol(value: String): this.type = set(predictionCol, value)
 
+  /** @group expertSetParam */
+  @Since("2.2.0")
+  def setColdStartStrategy(value: String): this.type = set(coldStartStrategy, value)
+
   @Since("2.0.0")
   override def transform(dataset: Dataset[_]): DataFrame = {
     transformSchema(dataset.schema)
@@ -260,13 +286,19 @@ class ALSModel private[ml] (
         Float.NaN
       }
     }
-    dataset
+    val predictions = dataset
       .join(userFactors, checkedCast(dataset($(userCol)).cast(DoubleType)) === userFactors("id"),
         "left")
       .join(itemFactors, checkedCast(dataset($(itemCol)).cast(DoubleType)) === itemFactors("id"),
         "left")
       .select(dataset("*"),
         predict(userFactors("features"), itemFactors("features")).as($(predictionCol)))
+    getColdStartStrategy match {
+      case ALSModel.Drop =>
+        predictions.na.drop("all", Seq($(predictionCol)))
+      case ALSModel.NaN =>
+        predictions
+    }
   }
 
   @Since("1.3.0")
@@ -290,6 +322,10 @@ class ALSModel private[ml] (
 @Since("1.6.0")
 object ALSModel extends MLReadable[ALSModel] {
 
+  private val NaN = "nan"
+  private val Drop = "drop"
+  private[recommendation] final val supportedColdStartStrategies = Array(NaN, Drop)
+
   @Since("1.6.0")
   override def read: MLReader[ALSModel] = new ALSModelReader
 
@@ -432,6 +468,10 @@ class ALS(@Since("1.4.0") override val uid: String) extends Estimator[ALSModel]
   @Since("2.0.0")
   def setFinalStorageLevel(value: String): this.type = set(finalStorageLevel, value)
 
+  /** @group expertSetParam */
+  @Since("2.2.0")
+  def setColdStartStrategy(value: String): this.type = set(coldStartStrategy, value)
+
   /**
    * Sets both numUserBlocks and numItemBlocks to the specific value.
    *
diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
index b923bacce23ca813defbfc5ad105d29ab1f8b5d5..c9e7b505b2bd2a69f4efee3f67e65c608d452d57 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
@@ -498,8 +498,8 @@ class ALSSuite
       (ex, act) =>
         ex.userFactors.first().getSeq[Float](1) === act.userFactors.first.getSeq[Float](1)
     } { (ex, act, _) =>
-      ex.transform(_: DataFrame).select("prediction").first.getFloat(0) ~==
-        act.transform(_: DataFrame).select("prediction").first.getFloat(0) absTol 1e-6
+      ex.transform(_: DataFrame).select("prediction").first.getDouble(0) ~==
+        act.transform(_: DataFrame).select("prediction").first.getDouble(0) absTol 1e-6
     }
   }
   // check user/item ids falling outside of Int range
@@ -547,6 +547,53 @@ class ALSSuite
       ALS.train(ratings)
     }
   }
+
+  test("ALS cold start user/item prediction strategy") {
+    val spark = this.spark
+    import spark.implicits._
+    import org.apache.spark.sql.functions._
+
+    val (ratings, _) = genExplicitTestData(numUsers = 4, numItems = 4, rank = 1)
+    val data = ratings.toDF
+    val knownUser = data.select(max("user")).as[Int].first()
+    val unknownUser = knownUser + 10
+    val knownItem = data.select(max("item")).as[Int].first()
+    val unknownItem = knownItem + 20
+    val test = Seq(
+      (unknownUser, unknownItem),
+      (knownUser, unknownItem),
+      (unknownUser, knownItem),
+      (knownUser, knownItem)
+    ).toDF("user", "item")
+
+    val als = new ALS().setMaxIter(1).setRank(1)
+    // default is 'nan'
+    val defaultModel = als.fit(data)
+    val defaultPredictions = defaultModel.transform(test).select("prediction").as[Float].collect()
+    assert(defaultPredictions.length == 4)
+    assert(defaultPredictions.slice(0, 3).forall(_.isNaN))
+    assert(!defaultPredictions.last.isNaN)
+
+    // check 'drop' strategy should filter out rows with unknown users/items
+    val dropPredictions = defaultModel
+      .setColdStartStrategy("drop")
+      .transform(test)
+      .select("prediction").as[Float].collect()
+    assert(dropPredictions.length == 1)
+    assert(!dropPredictions.head.isNaN)
+    assert(dropPredictions.head ~== defaultPredictions.last relTol 1e-14)
+  }
+
+  test("case insensitive cold start param value") {
+    val spark = this.spark
+    import spark.implicits._
+    val (ratings, _) = genExplicitTestData(numUsers = 2, numItems = 2, rank = 1)
+    val data = ratings.toDF
+    val model = new ALS().fit(data)
+    Seq("nan", "NaN", "Nan", "drop", "DROP", "Drop").foreach { s =>
+      model.setColdStartStrategy(s).transform(data)
+    }
+  }
 }
 
 class ALSCleanerSuite extends SparkFunSuite {
diff --git a/python/pyspark/ml/recommendation.py b/python/pyspark/ml/recommendation.py
index e28d38bd19f80b12f0cc66780f548a235f2f1f12..43f82daa9fcd3adea77792aabf6e50712686a118 100644
--- a/python/pyspark/ml/recommendation.py
+++ b/python/pyspark/ml/recommendation.py
@@ -125,19 +125,25 @@ class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, Ha
     finalStorageLevel = Param(Params._dummy(), "finalStorageLevel",
                               "StorageLevel for ALS model factors.",
                               typeConverter=TypeConverters.toString)
+    coldStartStrategy = Param(Params._dummy(), "coldStartStrategy", "strategy for dealing with " +
+                              "unknown or new users/items at prediction time. This may be useful " +
+                              "in cross-validation or production scenarios, for handling " +
+                              "user/item ids the model has not seen in the training data. " +
+                              "Supported values: 'nan', 'drop'.",
+                              typeConverter=TypeConverters.toString)
 
     @keyword_only
     def __init__(self, rank=10, maxIter=10, regParam=0.1, numUserBlocks=10, numItemBlocks=10,
                  implicitPrefs=False, alpha=1.0, userCol="user", itemCol="item", seed=None,
                  ratingCol="rating", nonnegative=False, checkpointInterval=10,
                  intermediateStorageLevel="MEMORY_AND_DISK",
-                 finalStorageLevel="MEMORY_AND_DISK"):
+                 finalStorageLevel="MEMORY_AND_DISK", coldStartStrategy="nan"):
         """
         __init__(self, rank=10, maxIter=10, regParam=0.1, numUserBlocks=10, numItemBlocks=10, \
                  implicitPrefs=false, alpha=1.0, userCol="user", itemCol="item", seed=None, \
                  ratingCol="rating", nonnegative=false, checkpointInterval=10, \
                  intermediateStorageLevel="MEMORY_AND_DISK", \
-                 finalStorageLevel="MEMORY_AND_DISK")
+                 finalStorageLevel="MEMORY_AND_DISK", coldStartStrategy="nan")
         """
         super(ALS, self).__init__()
         self._java_obj = self._new_java_obj("org.apache.spark.ml.recommendation.ALS", self.uid)
@@ -145,7 +151,7 @@ class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, Ha
                          implicitPrefs=False, alpha=1.0, userCol="user", itemCol="item",
                          ratingCol="rating", nonnegative=False, checkpointInterval=10,
                          intermediateStorageLevel="MEMORY_AND_DISK",
-                         finalStorageLevel="MEMORY_AND_DISK")
+                         finalStorageLevel="MEMORY_AND_DISK", coldStartStrategy="nan")
         kwargs = self.__init__._input_kwargs
         self.setParams(**kwargs)
 
@@ -155,13 +161,13 @@ class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, Ha
                   implicitPrefs=False, alpha=1.0, userCol="user", itemCol="item", seed=None,
                   ratingCol="rating", nonnegative=False, checkpointInterval=10,
                   intermediateStorageLevel="MEMORY_AND_DISK",
-                  finalStorageLevel="MEMORY_AND_DISK"):
+                  finalStorageLevel="MEMORY_AND_DISK", coldStartStrategy="nan"):
         """
        setParams(self, rank=10, maxIter=10, regParam=0.1, numUserBlocks=10, numItemBlocks=10, \
                   implicitPrefs=False, alpha=1.0, userCol="user", itemCol="item", seed=None, \
                   ratingCol="rating", nonnegative=False, checkpointInterval=10, \
                   intermediateStorageLevel="MEMORY_AND_DISK", \
-                  finalStorageLevel="MEMORY_AND_DISK")
+                  finalStorageLevel="MEMORY_AND_DISK", coldStartStrategy="nan")
         Sets params for ALS.
         """
         kwargs = self.setParams._input_kwargs
@@ -332,6 +338,20 @@ class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, Ha
         """
         return self.getOrDefault(self.finalStorageLevel)
 
+    @since("2.2.0")
+    def setColdStartStrategy(self, value):
+        """
+        Sets the value of :py:attr:`coldStartStrategy`.
+        """
+        return self._set(coldStartStrategy=value)
+
+    @since("2.2.0")
+    def getColdStartStrategy(self):
+        """
+        Gets the value of coldStartStrategy or its default value.
+        """
+        return self.getOrDefault(self.coldStartStrategy)
+
 
 class ALSModel(JavaModel, JavaMLWritable, JavaMLReadable):
     """