diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala index 799e881fad74afde680991cf03fe702afc970d9c..60dd7367053e2e28155a32eca18e9abfeb016b3a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala @@ -40,7 +40,8 @@ import org.apache.spark.ml.util._ import org.apache.spark.mllib.linalg.CholeskyDecomposition import org.apache.spark.mllib.optimization.NNLS import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{DataFrame, Dataset} +import org.apache.spark.sql.{DataFrame, Dataset, Row} +import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ import org.apache.spark.storage.StorageLevel @@ -284,18 +285,20 @@ class ALSModel private[ml] ( @Since("2.2.0") def setColdStartStrategy(value: String): this.type = set(coldStartStrategy, value) + private val predict = udf { (featuresA: Seq[Float], featuresB: Seq[Float]) => + if (featuresA != null && featuresB != null) { + // TODO(SPARK-19759): try dot-producting on Seqs or another non-converted type for + // potential optimization. + blas.sdot(rank, featuresA.toArray, 1, featuresB.toArray, 1) + } else { + Float.NaN + } + } + @Since("2.0.0") override def transform(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema) - // Register a UDF for DataFrame, and then // create a new column named map(predictionCol) by running the predict UDF. - val predict = udf { (userFeatures: Seq[Float], itemFeatures: Seq[Float]) => - if (userFeatures != null && itemFeatures != null) { - blas.sdot(rank, userFeatures.toArray, 1, itemFeatures.toArray, 1) - } else { - Float.NaN - } - } val predictions = dataset .join(userFactors, checkedCast(dataset($(userCol))) === userFactors("id"), "left") @@ -327,6 +330,64 @@ class ALSModel private[ml] ( @Since("1.6.0") override def write: MLWriter = new ALSModel.ALSModelWriter(this) + + /** + * Returns top `numItems` items recommended for each user, for all users. + * @param numItems max number of recommendations for each user + * @return a DataFrame of (userCol: Int, recommendations), where recommendations are + * stored as an array of (itemCol: Int, rating: Float) Rows. + */ + @Since("2.2.0") + def recommendForAllUsers(numItems: Int): DataFrame = { + recommendForAll(userFactors, itemFactors, $(userCol), $(itemCol), numItems) + } + + /** + * Returns top `numUsers` users recommended for each item, for all items. + * @param numUsers max number of recommendations for each item + * @return a DataFrame of (itemCol: Int, recommendations), where recommendations are + * stored as an array of (userCol: Int, rating: Float) Rows. + */ + @Since("2.2.0") + def recommendForAllItems(numUsers: Int): DataFrame = { + recommendForAll(itemFactors, userFactors, $(itemCol), $(userCol), numUsers) + } + + /** + * Makes recommendations for all users (or items). + * @param srcFactors src factors for which to generate recommendations + * @param dstFactors dst factors used to make recommendations + * @param srcOutputColumn name of the column for the source ID in the output DataFrame + * @param dstOutputColumn name of the column for the destination ID in the output DataFrame + * @param num max number of recommendations for each record + * @return a DataFrame of (srcOutputColumn: Int, recommendations), where recommendations are + * stored as an array of (dstOutputColumn: Int, rating: Float) Rows. + */ + private def recommendForAll( + srcFactors: DataFrame, + dstFactors: DataFrame, + srcOutputColumn: String, + dstOutputColumn: String, + num: Int): DataFrame = { + import srcFactors.sparkSession.implicits._ + + val ratings = srcFactors.crossJoin(dstFactors) + .select( + srcFactors("id"), + dstFactors("id"), + predict(srcFactors("features"), dstFactors("features"))) + // We'll force the IDs to be Int. Unfortunately this converts IDs to Int in the output. + val topKAggregator = new TopByKeyAggregator[Int, Int, Float](num, Ordering.by(_._2)) + val recs = ratings.as[(Int, Int, Float)].groupByKey(_._1).agg(topKAggregator.toColumn) + .toDF("id", "recommendations") + + val arrayType = ArrayType( + new StructType() + .add(dstOutputColumn, IntegerType) + .add("rating", FloatType) + ) + recs.select($"id" as srcOutputColumn, $"recommendations" cast arrayType) + } } @Since("1.6.0") diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/TopByKeyAggregator.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/TopByKeyAggregator.scala new file mode 100644 index 0000000000000000000000000000000000000000..517179c0eb9aea2032be5c6a7f4e19bf9d63efb5 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/TopByKeyAggregator.scala @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.recommendation + +import scala.language.implicitConversions +import scala.reflect.runtime.universe.TypeTag + +import org.apache.spark.sql.{Encoder, Encoders} +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.expressions.Aggregator +import org.apache.spark.util.BoundedPriorityQueue + + +/** + * Works on rows of the form (K1, K2, V) where K1 & K2 are IDs and V is the score value. Finds + * the top `num` K2 items based on the given Ordering. + */ +private[recommendation] class TopByKeyAggregator[K1: TypeTag, K2: TypeTag, V: TypeTag] + (num: Int, ord: Ordering[(K2, V)]) + extends Aggregator[(K1, K2, V), BoundedPriorityQueue[(K2, V)], Array[(K2, V)]] { + + override def zero: BoundedPriorityQueue[(K2, V)] = new BoundedPriorityQueue[(K2, V)](num)(ord) + + override def reduce( + q: BoundedPriorityQueue[(K2, V)], + a: (K1, K2, V)): BoundedPriorityQueue[(K2, V)] = { + q += {(a._2, a._3)} + } + + override def merge( + q1: BoundedPriorityQueue[(K2, V)], + q2: BoundedPriorityQueue[(K2, V)]): BoundedPriorityQueue[(K2, V)] = { + q1 ++= q2 + } + + override def finish(r: BoundedPriorityQueue[(K2, V)]): Array[(K2, V)] = { + r.toArray.sorted(ord.reverse) + } + + override def bufferEncoder: Encoder[BoundedPriorityQueue[(K2, V)]] = { + Encoders.kryo[BoundedPriorityQueue[(K2, V)]] + } + + override def outputEncoder: Encoder[Array[(K2, V)]] = ExpressionEncoder[Array[(K2, V)]]() +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala index c8228dd004374a7536abab51708b1022d1da09d3..e494ea89e63bd1c56606a045e32a31ceebb92017 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala @@ -22,6 +22,7 @@ import java.util.Random import scala.collection.mutable import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.WrappedArray import scala.collection.JavaConverters._ import scala.language.existentials @@ -660,6 +661,99 @@ class ALSSuite model.setColdStartStrategy(s).transform(data) } } + + private def getALSModel = { + val spark = this.spark + import spark.implicits._ + + val userFactors = Seq( + (0, Array(6.0f, 4.0f)), + (1, Array(3.0f, 4.0f)), + (2, Array(3.0f, 6.0f)) + ).toDF("id", "features") + val itemFactors = Seq( + (3, Array(5.0f, 6.0f)), + (4, Array(6.0f, 2.0f)), + (5, Array(3.0f, 6.0f)), + (6, Array(4.0f, 1.0f)) + ).toDF("id", "features") + val als = new ALS().setRank(2) + new ALSModel(als.uid, als.getRank, userFactors, itemFactors) + .setUserCol("user") + .setItemCol("item") + } + + test("recommendForAllUsers with k < num_items") { + val topItems = getALSModel.recommendForAllUsers(2) + assert(topItems.count() == 3) + assert(topItems.columns.contains("user")) + + val expected = Map( + 0 -> Array((3, 54f), (4, 44f)), + 1 -> Array((3, 39f), (5, 33f)), + 2 -> Array((3, 51f), (5, 45f)) + ) + checkRecommendations(topItems, expected, "item") + } + + test("recommendForAllUsers with k = num_items") { + val topItems = getALSModel.recommendForAllUsers(4) + assert(topItems.count() == 3) + assert(topItems.columns.contains("user")) + + val expected = Map( + 0 -> Array((3, 54f), (4, 44f), (5, 42f), (6, 28f)), + 1 -> Array((3, 39f), (5, 33f), (4, 26f), (6, 16f)), + 2 -> Array((3, 51f), (5, 45f), (4, 30f), (6, 18f)) + ) + checkRecommendations(topItems, expected, "item") + } + + test("recommendForAllItems with k < num_users") { + val topUsers = getALSModel.recommendForAllItems(2) + assert(topUsers.count() == 4) + assert(topUsers.columns.contains("item")) + + val expected = Map( + 3 -> Array((0, 54f), (2, 51f)), + 4 -> Array((0, 44f), (2, 30f)), + 5 -> Array((2, 45f), (0, 42f)), + 6 -> Array((0, 28f), (2, 18f)) + ) + checkRecommendations(topUsers, expected, "user") + } + + test("recommendForAllItems with k = num_users") { + val topUsers = getALSModel.recommendForAllItems(3) + assert(topUsers.count() == 4) + assert(topUsers.columns.contains("item")) + + val expected = Map( + 3 -> Array((0, 54f), (2, 51f), (1, 39f)), + 4 -> Array((0, 44f), (2, 30f), (1, 26f)), + 5 -> Array((2, 45f), (0, 42f), (1, 33f)), + 6 -> Array((0, 28f), (2, 18f), (1, 16f)) + ) + checkRecommendations(topUsers, expected, "user") + } + + private def checkRecommendations( + topK: DataFrame, + expected: Map[Int, Array[(Int, Float)]], + dstColName: String): Unit = { + val spark = this.spark + import spark.implicits._ + + assert(topK.columns.contains("recommendations")) + topK.as[(Int, Seq[(Int, Float)])].collect().foreach { case (id: Int, recs: Seq[(Int, Float)]) => + assert(recs === expected(id)) + } + topK.collect().foreach { row => + val recs = row.getAs[WrappedArray[Row]]("recommendations") + assert(recs(0).fieldIndex(dstColName) == 0) + assert(recs(0).fieldIndex("rating") == 1) + } + } } class ALSCleanerSuite extends SparkFunSuite { diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/TopByKeyAggregatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/TopByKeyAggregatorSuite.scala new file mode 100644 index 0000000000000000000000000000000000000000..5e763a8e908b8e8da1eb6f7501ed34ff123b97a6 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/TopByKeyAggregatorSuite.scala @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.recommendation + +import org.apache.spark.SparkFunSuite +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.Dataset + + +class TopByKeyAggregatorSuite extends SparkFunSuite with MLlibTestSparkContext { + + private def getTopK(k: Int): Dataset[(Int, Array[(Int, Float)])] = { + val sqlContext = spark.sqlContext + import sqlContext.implicits._ + + val topKAggregator = new TopByKeyAggregator[Int, Int, Float](k, Ordering.by(_._2)) + Seq( + (0, 3, 54f), + (0, 4, 44f), + (0, 5, 42f), + (0, 6, 28f), + (1, 3, 39f), + (2, 3, 51f), + (2, 5, 45f), + (2, 6, 18f) + ).toDS().groupByKey(_._1).agg(topKAggregator.toColumn) + } + + test("topByKey with k < #items") { + val topK = getTopK(2) + assert(topK.count() === 3) + + val expected = Map( + 0 -> Array((3, 54f), (4, 44f)), + 1 -> Array((3, 39f)), + 2 -> Array((3, 51f), (5, 45f)) + ) + checkTopK(topK, expected) + } + + test("topByKey with k > #items") { + val topK = getTopK(5) + assert(topK.count() === 3) + + val expected = Map( + 0 -> Array((3, 54f), (4, 44f), (5, 42f), (6, 28f)), + 1 -> Array((3, 39f)), + 2 -> Array((3, 51f), (5, 45f), (6, 18f)) + ) + checkTopK(topK, expected) + } + + private def checkTopK( + topK: Dataset[(Int, Array[(Int, Float)])], + expected: Map[Int, Array[(Int, Float)]]): Unit = { + topK.collect().foreach { case (id, recs) => assert(recs === expected(id)) } + } +}