diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala index 1562bf1beb7e1d82395a9f612842035fb4c4746a..d626f04599670a5e0af71570c608fbe62e24fa5c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala @@ -45,7 +45,7 @@ import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ import org.apache.spark.storage.StorageLevel -import org.apache.spark.util.Utils +import org.apache.spark.util.{BoundedPriorityQueue, Utils} import org.apache.spark.util.collection.{OpenHashMap, OpenHashSet, SortDataFormat, Sorter} import org.apache.spark.util.random.XORShiftRandom @@ -356,6 +356,19 @@ class ALSModel private[ml] ( /** * Makes recommendations for all users (or items). + * + * Note: the previous approach used for computing top-k recommendations + * used a cross-join followed by predicting a score for each row of the joined dataset. + * However, this results in exploding the size of intermediate data. While Spark SQL makes it + * relatively efficient, the approach implemented here is significantly more efficient. + * + * This approach groups factors into blocks and computes the top-k elements per block, + * using a simple dot product (instead of gemm) and an efficient [[BoundedPriorityQueue]]. + * It then computes the global top-k by aggregating the per block top-k elements with + * a [[TopByKeyAggregator]]. This significantly reduces the size of intermediate and shuffle data. + * This is the DataFrame equivalent to the approach used in + * [[org.apache.spark.mllib.recommendation.MatrixFactorizationModel]]. + * * @param srcFactors src factors for which to generate recommendations * @param dstFactors dst factors used to make recommendations * @param srcOutputColumn name of the column for the source ID in the output DataFrame @@ -372,11 +385,43 @@ class ALSModel private[ml] ( num: Int): DataFrame = { import srcFactors.sparkSession.implicits._ - val ratings = srcFactors.crossJoin(dstFactors) - .select( - srcFactors("id"), - dstFactors("id"), - predict(srcFactors("features"), dstFactors("features"))) + val srcFactorsBlocked = blockify(srcFactors.as[(Int, Array[Float])]) + val dstFactorsBlocked = blockify(dstFactors.as[(Int, Array[Float])]) + val ratings = srcFactorsBlocked.crossJoin(dstFactorsBlocked) + .as[(Seq[(Int, Array[Float])], Seq[(Int, Array[Float])])] + .flatMap { case (srcIter, dstIter) => + val m = srcIter.size + val n = math.min(dstIter.size, num) + val output = new Array[(Int, Int, Float)](m * n) + var j = 0 + val pq = new BoundedPriorityQueue[(Int, Float)](num)(Ordering.by(_._2)) + srcIter.foreach { case (srcId, srcFactor) => + dstIter.foreach { case (dstId, dstFactor) => + /* + * The below code is equivalent to + * `val score = blas.sdot(rank, srcFactor, 1, dstFactor, 1)` + * This handwritten version is as or more efficient as BLAS calls in this case. + */ + var score = 0.0f + var k = 0 + while (k < rank) { + score += srcFactor(k) * dstFactor(k) + k += 1 + } + pq += dstId -> score + } + val pqIter = pq.iterator + var i = 0 + while (i < n) { + val (dstId, score) = pqIter.next() + output(j + i) = (srcId, dstId, score) + i += 1 + } + j += n + pq.clear() + } + output.toSeq + } // We'll force the IDs to be Int. Unfortunately this converts IDs to Int in the output. val topKAggregator = new TopByKeyAggregator[Int, Int, Float](num, Ordering.by(_._2)) val recs = ratings.as[(Int, Int, Float)].groupByKey(_._1).agg(topKAggregator.toColumn) @@ -387,8 +432,20 @@ class ALSModel private[ml] ( .add(dstOutputColumn, IntegerType) .add("rating", FloatType) ) - recs.select($"id" as srcOutputColumn, $"recommendations" cast arrayType) + recs.select($"id".as(srcOutputColumn), $"recommendations".cast(arrayType)) } + + /** + * Blockifies factors to improve the efficiency of cross join + * TODO: SPARK-20443 - expose blockSize as a param? + */ + private def blockify( + factors: Dataset[(Int, Array[Float])], + blockSize: Int = 4096): Dataset[Seq[(Int, Array[Float])]] = { + import factors.sparkSession.implicits._ + factors.mapPartitions(_.grouped(blockSize)) + } + } @Since("1.6.0")