From 58518d070777fc0665c4d02bad8adf910807df98 Mon Sep 17 00:00:00 2001
From: Nick Pentreath <nickp@za.ibm.com>
Date: Mon, 8 May 2017 12:45:00 +0200
Subject: [PATCH] [SPARK-20596][ML][TEST] Consolidate and improve ALS
 recommendAll test cases

Existing test cases for `recommendForAllX` methods (added in [SPARK-19535](https://issues.apache.org/jira/browse/SPARK-19535)) test `k < num items` and `k = num items`. Technically we should also test that `k > num items` returns the same results as `k = num items`.

## How was this patch tested?

Updated existing unit tests.

Author: Nick Pentreath <nickp@za.ibm.com>

Closes #17860 from MLnick/SPARK-20596-als-rec-tests.
---
 .../spark/ml/recommendation/ALSSuite.scala    | 63 ++++++++-----------
 1 file changed, 25 insertions(+), 38 deletions(-)

diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
index 7574af3d77..9d31e79263 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
@@ -671,58 +671,45 @@ class ALSSuite
       .setItemCol("item")
   }
 
-  test("recommendForAllUsers with k < num_items") {
-    val topItems = getALSModel.recommendForAllUsers(2)
-    assert(topItems.count() == 3)
-    assert(topItems.columns.contains("user"))
-
-    val expected = Map(
-      0 -> Array((3, 54f), (4, 44f)),
-      1 -> Array((3, 39f), (5, 33f)),
-      2 -> Array((3, 51f), (5, 45f))
-    )
-    checkRecommendations(topItems, expected, "item")
-  }
-
-  test("recommendForAllUsers with k = num_items") {
-    val topItems = getALSModel.recommendForAllUsers(4)
-    assert(topItems.count() == 3)
-    assert(topItems.columns.contains("user"))
-
+  test("recommendForAllUsers with k <, = and > num_items") {
+    val model = getALSModel
+    val numUsers = model.userFactors.count
+    val numItems = model.itemFactors.count
     val expected = Map(
       0 -> Array((3, 54f), (4, 44f), (5, 42f), (6, 28f)),
       1 -> Array((3, 39f), (5, 33f), (4, 26f), (6, 16f)),
       2 -> Array((3, 51f), (5, 45f), (4, 30f), (6, 18f))
     )
-    checkRecommendations(topItems, expected, "item")
-  }
 
-  test("recommendForAllItems with k < num_users") {
-    val topUsers = getALSModel.recommendForAllItems(2)
-    assert(topUsers.count() == 4)
-    assert(topUsers.columns.contains("item"))
-
-    val expected = Map(
-      3 -> Array((0, 54f), (2, 51f)),
-      4 -> Array((0, 44f), (2, 30f)),
-      5 -> Array((2, 45f), (0, 42f)),
-      6 -> Array((0, 28f), (2, 18f))
-    )
-    checkRecommendations(topUsers, expected, "user")
+    Seq(2, 4, 6).foreach { k =>
+      val n = math.min(k, numItems).toInt
+      val expectedUpToN = expected.mapValues(_.slice(0, n))
+      val topItems = model.recommendForAllUsers(k)
+      assert(topItems.count() == numUsers)
+      assert(topItems.columns.contains("user"))
+      checkRecommendations(topItems, expectedUpToN, "item")
+    }
   }
 
-  test("recommendForAllItems with k = num_users") {
-    val topUsers = getALSModel.recommendForAllItems(3)
-    assert(topUsers.count() == 4)
-    assert(topUsers.columns.contains("item"))
-
+  test("recommendForAllItems with k <, = and > num_users") {
+    val model = getALSModel
+    val numUsers = model.userFactors.count
+    val numItems = model.itemFactors.count
     val expected = Map(
       3 -> Array((0, 54f), (2, 51f), (1, 39f)),
       4 -> Array((0, 44f), (2, 30f), (1, 26f)),
       5 -> Array((2, 45f), (0, 42f), (1, 33f)),
       6 -> Array((0, 28f), (2, 18f), (1, 16f))
     )
-    checkRecommendations(topUsers, expected, "user")
+
+    Seq(2, 3, 4).foreach { k =>
+      val n = math.min(k, numUsers).toInt
+      val expectedUpToN = expected.mapValues(_.slice(0, n))
+      val topUsers = getALSModel.recommendForAllItems(k)
+      assert(topUsers.count() == numItems)
+      assert(topUsers.columns.contains("item"))
+      checkRecommendations(topUsers, expectedUpToN, "user")
+    }
   }
 
   private def checkRecommendations(
-- 
GitLab