Commit db951378 authored by Xiangrui Meng

[SPARK-7922] [MLLIB] use DataFrames for user/item factors in ALSModel

Expose the user/item factors as DataFrames. This is more consistent with the pipeline API and helps keep the API consistent across languages. This PR also removes the fitting params from `ALSModel`.

coderxiang

Author: Xiangrui Meng <meng@databricks.com>

Closes #6468 from mengxr/SPARK-7922 and squashes the following commits:

7bfb1d5 [Xiangrui Meng] update ALSModel in PySpark
1ba5607 [Xiangrui Meng] use DataFrames for user/item factors in ALS
parent cd3d9a5c
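For context, a minimal sketch of how the reworked `ALSModel` is intended to be used after this change. The column names and the `ratings`/`test` DataFrames below are illustrative placeholders, not part of the patch:

    import org.apache.spark.ml.recommendation.ALS

    // `ratings` is assumed to be a DataFrame with integer "user"/"item" columns
    // and a float/double "rating" column, as validateAndTransformSchema requires.
    val als = new ALS()
      .setRank(10)
      .setMaxIter(5)
      .setUserCol("user")
      .setItemCol("item")
      .setRatingCol("rating")
    val model = als.fit(ratings)

    // The factors are now exposed as DataFrames with columns `id` and `features`
    // instead of RDD[(Int, Array[Float])], matching the pipeline API.
    val rank = model.rank
    model.userFactors.show()
    model.itemFactors.show()

    // Scoring still goes through transform(), which appends the prediction column.
    val predictions = model.transform(test)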
@@ -35,21 +35,46 @@ import org.apache.spark.annotation.{DeveloperApi, Experimental}
 import org.apache.spark.ml.{Estimator, Model}
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared._
-import org.apache.spark.ml.util.Identifiable
+import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
 import org.apache.spark.mllib.optimization.NNLS
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.functions._
-import org.apache.spark.sql.types.{DoubleType, FloatType, IntegerType, StructField, StructType}
+import org.apache.spark.sql.types.{DoubleType, FloatType, IntegerType, StructType}
 import org.apache.spark.storage.StorageLevel
 import org.apache.spark.util.Utils
 import org.apache.spark.util.collection.{OpenHashMap, OpenHashSet, SortDataFormat, Sorter}
 import org.apache.spark.util.random.XORShiftRandom
 
+/**
+ * Common params for ALS and ALSModel.
+ */
+private[recommendation] trait ALSModelParams extends Params with HasPredictionCol {
+
+  /**
+   * Param for the column name for user ids.
+   * Default: "user"
+   * @group param
+   */
+  val userCol = new Param[String](this, "userCol", "column name for user ids")
+
+  /** @group getParam */
+  def getUserCol: String = $(userCol)
+
+  /**
+   * Param for the column name for item ids.
+   * Default: "item"
+   * @group param
+   */
+  val itemCol = new Param[String](this, "itemCol", "column name for item ids")
+
+  /** @group getParam */
+  def getItemCol: String = $(itemCol)
+}
+
 /**
  * Common params for ALS.
  */
-private[recommendation] trait ALSParams extends Params with HasMaxIter with HasRegParam
+private[recommendation] trait ALSParams extends ALSModelParams with HasMaxIter with HasRegParam
   with HasPredictionCol with HasCheckpointInterval with HasSeed {
 
   /**
@@ -105,26 +130,6 @@ private[recommendation] trait ALSParams extends Params with HasMaxIter with HasR
   /** @group getParam */
   def getAlpha: Double = $(alpha)
 
-  /**
-   * Param for the column name for user ids.
-   * Default: "user"
-   * @group param
-   */
-  val userCol = new Param[String](this, "userCol", "column name for user ids")
-
-  /** @group getParam */
-  def getUserCol: String = $(userCol)
-
-  /**
-   * Param for the column name for item ids.
-   * Default: "item"
-   * @group param
-   */
-  val itemCol = new Param[String](this, "itemCol", "column name for item ids")
-
-  /** @group getParam */
-  def getItemCol: String = $(itemCol)
-
   /**
    * Param for the column name for ratings.
    * Default: "rating"
@@ -156,55 +161,60 @@ private[recommendation] trait ALSParams extends Params with HasMaxIter with HasR
    * @return output schema
    */
   protected def validateAndTransformSchema(schema: StructType): StructType = {
-    require(schema($(userCol)).dataType == IntegerType)
-    require(schema($(itemCol)).dataType== IntegerType)
+    SchemaUtils.checkColumnType(schema, $(userCol), IntegerType)
+    SchemaUtils.checkColumnType(schema, $(itemCol), IntegerType)
     val ratingType = schema($(ratingCol)).dataType
     require(ratingType == FloatType || ratingType == DoubleType)
-    val predictionColName = $(predictionCol)
-    require(!schema.fieldNames.contains(predictionColName),
-      s"Prediction column $predictionColName already exists.")
-    val newFields = schema.fields :+ StructField($(predictionCol), FloatType, nullable = false)
-    StructType(newFields)
+    SchemaUtils.appendColumn(schema, $(predictionCol), FloatType)
   }
 }
 
 /**
  * :: Experimental ::
  * Model fitted by ALS.
+ *
+ * @param rank rank of the matrix factorization model
+ * @param userFactors a DataFrame that stores user factors in two columns: `id` and `features`
+ * @param itemFactors a DataFrame that stores item factors in two columns: `id` and `features`
  */
 @Experimental
 class ALSModel private[ml] (
     override val uid: String,
-    k: Int,
-    userFactors: RDD[(Int, Array[Float])],
-    itemFactors: RDD[(Int, Array[Float])])
-  extends Model[ALSModel] with ALSParams {
+    val rank: Int,
+    @transient val userFactors: DataFrame,
+    @transient val itemFactors: DataFrame)
+  extends Model[ALSModel] with ALSModelParams {
+
+  /** @group setParam */
+  def setUserCol(value: String): this.type = set(userCol, value)
+
+  /** @group setParam */
+  def setItemCol(value: String): this.type = set(itemCol, value)
 
   /** @group setParam */
   def setPredictionCol(value: String): this.type = set(predictionCol, value)
 
   override def transform(dataset: DataFrame): DataFrame = {
-    import dataset.sqlContext.implicits._
-    val users = userFactors.toDF("id", "features")
-    val items = itemFactors.toDF("id", "features")
     // Register a UDF for DataFrame, and then
     // create a new column named map(predictionCol) by running the predict UDF.
     val predict = udf { (userFeatures: Seq[Float], itemFeatures: Seq[Float]) =>
       if (userFeatures != null && itemFeatures != null) {
-        blas.sdot(k, userFeatures.toArray, 1, itemFeatures.toArray, 1)
+        blas.sdot(rank, userFeatures.toArray, 1, itemFeatures.toArray, 1)
       } else {
         Float.NaN
       }
     }
     dataset
-      .join(users, dataset($(userCol)) === users("id"), "left")
-      .join(items, dataset($(itemCol)) === items("id"), "left")
-      .select(dataset("*"), predict(users("features"), items("features")).as($(predictionCol)))
+      .join(userFactors, dataset($(userCol)) === userFactors("id"), "left")
+      .join(itemFactors, dataset($(itemCol)) === itemFactors("id"), "left")
+      .select(dataset("*"),
+        predict(userFactors("features"), itemFactors("features")).as($(predictionCol)))
   }
 
   override def transformSchema(schema: StructType): StructType = {
-    validateAndTransformSchema(schema)
+    SchemaUtils.checkColumnType(schema, $(userCol), IntegerType)
+    SchemaUtils.checkColumnType(schema, $(itemCol), IntegerType)
+    SchemaUtils.appendColumn(schema, $(predictionCol), FloatType)
   }
 }
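The new `transform` above computes predictions by left-joining the input against the two factor DataFrames on their `id` columns and applying a dot-product UDF, so rows whose user or item was unseen during fitting come back as `Float.NaN`. The following is a standalone sketch of the same pattern using only the public DataFrame API; the `dataset`, `userFactors`, and `itemFactors` inputs and the hard-coded "user"/"item" column names are assumptions for illustration, not part of the patch:

    import org.apache.spark.sql.DataFrame
    import org.apache.spark.sql.functions.udf

    def score(dataset: DataFrame, userFactors: DataFrame, itemFactors: DataFrame): DataFrame = {
      // Plain dot product instead of blas.sdot; NaN when either factor vector is
      // missing after the left (outer) joins below.
      val predict = udf { (userFeatures: Seq[Float], itemFeatures: Seq[Float]) =>
        if (userFeatures != null && itemFeatures != null) {
          (userFeatures, itemFeatures).zipped.map(_ * _).sum
        } else {
          Float.NaN
        }
      }
      dataset
        .join(userFactors, dataset("user") === userFactors("id"), "left")
        .join(itemFactors, dataset("item") === itemFactors("id"), "left")
        .select(dataset("*"),
          predict(userFactors("features"), itemFactors("features")).as("prediction"))
    }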
@@ -299,6 +309,7 @@ class ALS(override val uid: String) extends Estimator[ALSModel] with ALSParams {
   }
 
   override def fit(dataset: DataFrame): ALSModel = {
+    import dataset.sqlContext.implicits._
     val ratings = dataset
       .select(col($(userCol)).cast(IntegerType), col($(itemCol)).cast(IntegerType),
         col($(ratingCol)).cast(FloatType))
@@ -310,7 +321,9 @@ class ALS(override val uid: String) extends Estimator[ALSModel] with ALSParams {
       maxIter = $(maxIter), regParam = $(regParam), implicitPrefs = $(implicitPrefs),
       alpha = $(alpha), nonnegative = $(nonnegative),
       checkpointInterval = $(checkpointInterval), seed = $(seed))
-    val model = new ALSModel(uid, $(rank), userFactors, itemFactors).setParent(this)
+    val userDF = userFactors.toDF("id", "features")
+    val itemDF = itemFactors.toDF("id", "features")
+    val model = new ALSModel(uid, $(rank), userDF, itemDF).setParent(this)
     copyValues(model)
   }
...
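On the `fit()` side, the ALS solver still produces the factors as `RDD[(Int, Array[Float])]`; the new step is only the conversion to DataFrames with `toDF("id", "features")`, which is why `fit()` now imports the dataset's SQLContext implicits. A minimal sketch of that conversion in isolation (the `sqlContext` and `factors` inputs are assumed to exist):

    import org.apache.spark.rdd.RDD
    import org.apache.spark.sql.{DataFrame, SQLContext}

    def toFactorDF(sqlContext: SQLContext, factors: RDD[(Int, Array[Float])]): DataFrame = {
      // Brings the .toDF enrichment for RDDs of tuples into scope.
      import sqlContext.implicits._
      // Resulting schema: id: int, features: array<float>.
      factors.toDF("id", "features")
    }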
@@ -63,8 +63,15 @@ class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, Ha
     indicated user preferences rather than explicit ratings given to
     items.
 
+    >>> df = sqlContext.createDataFrame(
+    ...     [(0, 0, 4.0), (0, 1, 2.0), (1, 1, 3.0), (1, 2, 4.0), (2, 1, 1.0), (2, 2, 5.0)],
+    ...     ["user", "item", "rating"])
     >>> als = ALS(rank=10, maxIter=5)
     >>> model = als.fit(df)
+    >>> model.rank
+    10
+    >>> model.userFactors.orderBy("id").collect()
+    [Row(id=0, features=[...]), Row(id=1, ...), Row(id=2, ...)]
     >>> test = sqlContext.createDataFrame([(0, 2), (1, 0), (2, 0)], ["user", "item"])
     >>> predictions = sorted(model.transform(test).collect(), key=lambda r: r[0])
     >>> predictions[0]
@@ -260,6 +267,27 @@ class ALSModel(JavaModel):
     Model fitted by ALS.
     """
+
+    @property
+    def rank(self):
+        """rank of the matrix factorization model"""
+        return self._call_java("rank")
+
+    @property
+    def userFactors(self):
+        """
+        a DataFrame that stores user factors in two columns: `id` and
+        `features`
+        """
+        return self._call_java("userFactors")
+
+    @property
+    def itemFactors(self):
+        """
+        a DataFrame that stores item factors in two columns: `id` and
+        `features`
+        """
+        return self._call_java("itemFactors")
 
 
 if __name__ == "__main__":
     import doctest
@@ -272,8 +300,6 @@ if __name__ == "__main__":
     sqlContext = SQLContext(sc)
     globs['sc'] = sc
     globs['sqlContext'] = sqlContext
-    globs['df'] = sqlContext.createDataFrame([(0, 0, 4.0), (0, 1, 2.0), (1, 1, 3.0), (1, 2, 4.0),
-                                              (2, 1, 1.0), (2, 2, 5.0)], ["user", "item", "rating"])
     (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
     sc.stop()
     if failure_count:
...
@@ -27,7 +27,7 @@ from py4j.java_collections import ListConverter, JavaArray, JavaList
 
 from pyspark import RDD, SparkContext
 from pyspark.serializers import PickleSerializer, AutoBatchedSerializer
+from pyspark.sql import DataFrame, SQLContext
 
 # Hack for support float('inf') in Py4j
 _old_smart_decode = py4j.protocol.smart_decode
@@ -99,6 +99,9 @@ def _java2py(sc, r, encoding="bytes"):
             jrdd = sc._jvm.SerDe.javaToPython(r)
             return RDD(jrdd, sc)
 
+        if clsName == 'DataFrame':
+            return DataFrame(r, SQLContext(sc))
+
         if clsName in _picklable_classes:
             r = sc._jvm.SerDe.dumps(r)
         elif isinstance(r, (JavaArray, JavaList)):
...