Skip to content
Snippets Groups Projects
Commit b5fb1410 authored by Xiangrui Meng
Browse files

[SPARK-4604][MLLIB] make MatrixFactorizationModel public

Users can now construct an MF model directly. I added a note about the performance implications.

Author: Xiangrui Meng <meng@databricks.com>

Closes #3459 from mengxr/SPARK-4604 and squashes the following commits:

f64bcd3 [Xiangrui Meng] organize imports
ed08214 [Xiangrui Meng] check preconditions and unit tests
a624c12 [Xiangrui Meng] make MatrixFactorizationModel public
parent 4d95526a
No related branches found
No related tags found
No related merge requests found
......@@ -21,23 +21,45 @@ import java.lang.{Integer => JavaInteger}
import org.jblas.DoubleMatrix
import org.apache.spark.SparkContext._
import org.apache.spark.Logging
import org.apache.spark.api.java.{JavaPairRDD, JavaRDD}
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel
/**
* Model representing the result of matrix factorization.
*
* Note: If you create the model directly using the constructor, please be aware that fast
* prediction requires cached user/product features and their associated partitioners.
*
* @param rank Rank for the features in this model.
* @param userFeatures RDD of tuples where each tuple represents the userId and
* the features computed for this user.
* @param productFeatures RDD of tuples where each tuple represents the productId
* and the features computed for this product.
*/
class MatrixFactorizationModel private[mllib] (
class MatrixFactorizationModel(
val rank: Int,
val userFeatures: RDD[(Int, Array[Double])],
val productFeatures: RDD[(Int, Array[Double])]) extends Serializable {
val productFeatures: RDD[(Int, Array[Double])]) extends Serializable with Logging {
require(rank > 0)
validateFeatures("User", userFeatures)
validateFeatures("Product", productFeatures)
/**
 * Validates factors and warns users if there are performance concerns.
 *
 * Only the first record returned by `take(1)` is inspected for the dimension check.
 *
 * @param name label used in the error and warning messages
 * @param features RDD of (id, feature vector) pairs to validate
 * @throws IllegalArgumentException if `features` is empty, or if its first feature vector
 *                                  does not have exactly `rank` entries
 */
private def validateFeatures(name: String, features: RDD[(Int, Array[Double])]): Unit = {
  // Use take(1) rather than first(): first() throws UnsupportedOperationException on an
  // empty RDD, which would escape the constructor as an opaque "empty collection" error
  // instead of the IllegalArgumentException used for the other preconditions.
  val firstFactor = features.take(1).headOption.map(_._2)
  require(firstFactor.isDefined,
    s"$name features are empty; expected at least one factor of rank $rank.")
  require(firstFactor.forall(_.length == rank),
    s"$name feature dimension does not match the rank $rank.")
  if (features.partitioner.isEmpty) {
    logWarning(s"$name factor does not have a partitioner. "
      + "Prediction on individual records could be slow.")
  }
  if (features.getStorageLevel == StorageLevel.NONE) {
    logWarning(s"$name factor is not cached. Prediction could be slow.")
  }
}
/** Predict the rating of one user for one product. */
def predict(user: Int, product: Int): Double = {
val userVector = new DoubleMatrix(userFeatures.lookup(user).head)
......
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.mllib.recommendation
import org.scalatest.FunSuite
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.rdd.RDD
/**
 * Tests for building a [[MatrixFactorizationModel]] directly through its constructor:
 * one successful construction with a prediction check, followed by argument combinations
 * that the constructor is expected to reject.
 */
class MatrixFactorizationModelSuite extends FunSuite with MLlibTestSparkContext {

  val rank = 2

  // Assigned in beforeAll(), after super.beforeAll() has run
  // (presumably that is where `sc` becomes usable -- confirm against MLlibTestSparkContext).
  var userFeatures: RDD[(Int, Array[Double])] = _
  var prodFeatures: RDD[(Int, Array[Double])] = _

  override def beforeAll(): Unit = {
    super.beforeAll()
    val users = Seq((0, Array(1.0, 2.0)), (1, Array(3.0, 4.0)))
    val products = Seq((2, Array(5.0, 6.0)))
    userFeatures = sc.parallelize(users)
    prodFeatures = sc.parallelize(products)
  }

  test("constructor") {
    /** Asserts that constructing a model from the given arguments is rejected. */
    def expectRejected(
        modelRank: Int,
        users: RDD[(Int, Array[Double])],
        products: RDD[(Int, Array[Double])]): Unit = {
      intercept[IllegalArgumentException] {
        new MatrixFactorizationModel(modelRank, users, products)
      }
    }

    // Valid arguments: 17.0 equals 1*5 + 2*6 for the fixture vectors of user 0 and product 2.
    val validModel = new MatrixFactorizationModel(rank, userFeatures, prodFeatures)
    assert(validModel.predict(0, 2) ~== 17.0 relTol 1e-14)

    // Rank that disagrees with the two-dimensional fixture vectors.
    expectRejected(1, userFeatures, prodFeatures)

    // User vectors with one entry each while rank is 2.
    val shortUserFeatures = sc.parallelize(Seq((0, Array(1.0)), (1, Array(3.0))))
    expectRejected(rank, shortUserFeatures, prodFeatures)

    // Product vector with one entry while rank is 2.
    val shortProdFeatures = sc.parallelize(Seq((2, Array(5.0))))
    expectRejected(rank, userFeatures, shortProdFeatures)
  }
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment