Commit 2acdf10b authored by Joseph K. Bradley, committed by Xiangrui Meng

[SPARK-6789][ML] Add Readable, Writable support for spark.ml ALS, ALSModel

Also modifies DefaultParamsWriter.saveMetadata to take optional extra metadata.

CC: mengxr yanboliang

Author: Joseph K. Bradley <joseph@databricks.com>

Closes #9786 from jkbradley/als-io.
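
For illustration only (not part of the patch): a minimal sketch of how the new Writable/Readable support is intended to be used, assuming a spark-shell 1.6 session where sc and sqlContext already exist; the toy ratings data and the save paths below are placeholders.

import org.apache.spark.ml.recommendation.{ALS, ALSModel}

// Toy ratings DataFrame matching ALS's default userCol/itemCol/ratingCol.
val ratings = sqlContext.createDataFrame(Seq(
  (0, 0, 4.0f), (0, 1, 2.0f), (1, 1, 3.0f), (1, 2, 4.0f)
)).toDF("user", "item", "rating")

val als = new ALS().setMaxIter(5).setRank(2)
val model = als.fit(ratings)

// New in this patch: persist the estimator (via DefaultParamsWriter)
// and the fitted model (via ALSModel.ALSModelWriter).
als.write.save("/tmp/spark-als")           // placeholder path
model.write.save("/tmp/spark-als-model")   // placeholder path

// ...and read them back.
val als2 = ALS.load("/tmp/spark-als")
val model2 = ALSModel.load("/tmp/spark-als-model")
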
parent 045a4f04
@@ -27,13 +27,16 @@ import scala.util.hashing.byteswap64
import com.github.fommil.netlib.BLAS.{getInstance => blas}
import org.apache.hadoop.fs.{FileSystem, Path}
import org.json4s.{DefaultFormats, JValue}
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods._
import org.apache.spark.{Logging, Partitioner}
import org.apache.spark.annotation.{DeveloperApi, Experimental}
import org.apache.spark.annotation.{Since, DeveloperApi, Experimental}
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
import org.apache.spark.ml.util._
import org.apache.spark.mllib.linalg.CholeskyDecomposition
import org.apache.spark.mllib.optimization.NNLS
import org.apache.spark.rdd.RDD
@@ -182,7 +185,7 @@ class ALSModel private[ml] (
val rank: Int,
@transient val userFactors: DataFrame,
@transient val itemFactors: DataFrame)
extends Model[ALSModel] with ALSModelParams {
extends Model[ALSModel] with ALSModelParams with Writable {
/** @group setParam */
def setUserCol(value: String): this.type = set(userCol, value)
@@ -220,8 +223,60 @@ class ALSModel private[ml] (
val copied = new ALSModel(uid, rank, userFactors, itemFactors)
copyValues(copied, extra).setParent(parent)
}
@Since("1.6.0")
override def write: Writer = new ALSModel.ALSModelWriter(this)
}
@Since("1.6.0")
object ALSModel extends Readable[ALSModel] {
@Since("1.6.0")
override def read: Reader[ALSModel] = new ALSModelReader
@Since("1.6.0")
override def load(path: String): ALSModel = read.load(path)
private[recommendation] class ALSModelWriter(instance: ALSModel) extends Writer {
override protected def saveImpl(path: String): Unit = {
val extraMetadata = render("rank" -> instance.rank)
DefaultParamsWriter.saveMetadata(instance, path, sc, Some(extraMetadata))
val userPath = new Path(path, "userFactors").toString
instance.userFactors.write.format("parquet").save(userPath)
val itemPath = new Path(path, "itemFactors").toString
instance.itemFactors.write.format("parquet").save(itemPath)
}
}
private[recommendation] class ALSModelReader extends Reader[ALSModel] {
/** Checked against metadata when loading model */
private val className = "org.apache.spark.ml.recommendation.ALSModel"
override def load(path: String): ALSModel = {
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
implicit val format = DefaultFormats
val rank: Int = metadata.extraMetadata match {
case Some(m: JValue) =>
(m \ "rank").extract[Int]
case None =>
throw new RuntimeException(s"ALSModel loader could not read rank from JSON metadata:" +
s" ${metadata.metadataStr}")
}
val userPath = new Path(path, "userFactors").toString
val userFactors = sqlContext.read.format("parquet").load(userPath)
val itemPath = new Path(path, "itemFactors").toString
val itemFactors = sqlContext.read.format("parquet").load(itemPath)
val model = new ALSModel(metadata.uid, rank, userFactors, itemFactors)
DefaultParamsReader.getAndSetParams(model, metadata)
model
}
}
}
/**
* :: Experimental ::
@@ -254,7 +309,7 @@ class ALSModel private[ml] (
* preferences rather than explicit ratings given to items.
*/
@Experimental
class ALS(override val uid: String) extends Estimator[ALSModel] with ALSParams {
class ALS(override val uid: String) extends Estimator[ALSModel] with ALSParams with Writable {
import org.apache.spark.ml.recommendation.ALS.Rating
@@ -336,8 +391,12 @@ class ALS(override val uid: String) extends Estimator[ALSModel] with ALSParams {
}
override def copy(extra: ParamMap): ALS = defaultCopy(extra)
@Since("1.6.0")
override def write: Writer = new DefaultParamsWriter(this)
}
/**
* :: DeveloperApi ::
* An implementation of ALS that supports generic ID types, specialized for Int and Long. This is
@@ -347,7 +406,7 @@ class ALS(override val uid: String) extends Estimator[ALSModel] with ALSParams {
* than 2 billion.
*/
@DeveloperApi
object ALS extends Logging {
object ALS extends Readable[ALS] with Logging {
/**
* :: DeveloperApi ::
@@ -356,6 +415,12 @@ object ALS extends Logging {
@DeveloperApi
case class Rating[@specialized(Int, Long) ID](user: ID, item: ID, rating: Float)
@Since("1.6.0")
override def read: Reader[ALS] = new DefaultParamsReader[ALS]
@Since("1.6.0")
override def load(path: String): ALS = read.load(path)
/** Trait for least squares solvers applied to the normal equation. */
private[recommendation] trait LeastSquaresNESolver extends Serializable {
/** Solves a least squares problem with regularization (possibly with other constraints). */
@@ -194,7 +194,11 @@ private[ml] object DefaultParamsWriter {
* - uid
* - paramMap: These must be encodable using [[org.apache.spark.ml.param.Param.jsonEncode()]].
*/
def saveMetadata(instance: Params, path: String, sc: SparkContext): Unit = {
def saveMetadata(
instance: Params,
path: String,
sc: SparkContext,
extraMetadata: Option[JValue] = None): Unit = {
val uid = instance.uid
val cls = instance.getClass.getName
val params = instance.extractParamMap().toSeq.asInstanceOf[Seq[ParamPair[Any]]]
@@ -205,7 +209,8 @@ private[ml] object DefaultParamsWriter {
("timestamp" -> System.currentTimeMillis()) ~
("sparkVersion" -> sc.version) ~
("uid" -> uid) ~
("paramMap" -> jsonParams)
("paramMap" -> jsonParams) ~
("extraMetadata" -> extraMetadata)
val metadataPath = new Path(path, "metadata").toString
val metadataJson = compact(render(metadata))
sc.parallelize(Seq(metadataJson), 1).saveAsTextFile(metadataPath)
@@ -236,6 +241,7 @@ private[ml] object DefaultParamsReader {
/**
* All info from metadata file.
* @param params paramMap, as a [[JValue]]
* @param extraMetadata Extra metadata saved by [[DefaultParamsWriter.saveMetadata()]]
* @param metadataStr Full metadata file String (for debugging)
*/
case class Metadata(
@@ -244,6 +250,7 @@ private[ml] object DefaultParamsReader {
timestamp: Long,
sparkVersion: String,
params: JValue,
extraMetadata: Option[JValue],
metadataStr: String)
/**
@@ -262,12 +269,13 @@ private[ml] object DefaultParamsReader {
val timestamp = (metadata \ "timestamp").extract[Long]
val sparkVersion = (metadata \ "sparkVersion").extract[String]
val params = metadata \ "paramMap"
val extraMetadata = (metadata \ "extraMetadata").extract[Option[JValue]]
if (expectedClassName.nonEmpty) {
require(className == expectedClassName, s"Error loading metadata: Expected class name" +
s" $expectedClassName but found class name $className")
}
Metadata(className, uid, timestamp, sparkVersion, params, metadataStr)
Metadata(className, uid, timestamp, sparkVersion, params, extraMetadata, metadataStr)
}
/**
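
Aside (not part of the diff): a small json4s round-trip sketch, mirroring how ALSModelWriter embeds the rank via render and how ALSModelReader pulls it back out of extraMetadata; the uid and rank values here are made up.

import org.json4s.{DefaultFormats, JValue}
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods._

implicit val format = DefaultFormats

// Writer side: wrap the model-specific field in a JValue, as ALSModelWriter does.
val extraMetadata: JValue = render("rank" -> 10)
val metadataJson = compact(render(("uid" -> "als_example") ~ ("extraMetadata" -> extraMetadata)))

// Reader side: parse the metadata string and extract the rank again.
val parsed = parse(metadataJson)
val rank = (parsed \ "extraMetadata" \ "rank").extract[Int]   // 10
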
@@ -17,7 +17,6 @@
package org.apache.spark.ml.recommendation
import java.io.File
import java.util.Random
import scala.collection.mutable
@@ -26,28 +25,26 @@ import scala.language.existentials
import com.github.fommil.netlib.BLAS.{getInstance => blas}
import org.apache.spark.util.Utils
import org.apache.spark.{Logging, SparkException, SparkFunSuite}
import org.apache.spark.ml.recommendation.ALS._
import org.apache.spark.ml.util.MLTestingUtils
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Row, SQLContext}
import org.apache.spark.util.Utils
import org.apache.spark.sql.{DataFrame, Row}
class ALSSuite extends SparkFunSuite with MLlibTestSparkContext with Logging {
private var tempDir: File = _
class ALSSuite
extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest with Logging {
override def beforeAll(): Unit = {
super.beforeAll()
tempDir = Utils.createTempDir()
sc.setCheckpointDir(tempDir.getAbsolutePath)
}
override def afterAll(): Unit = {
Utils.deleteRecursively(tempDir)
super.afterAll()
}
@@ -186,7 +183,7 @@ class ALSSuite extends SparkFunSuite with MLlibTestSparkContext with Logging {
assert(compressed.dstPtrs.toSeq === Seq(0, 2, 3, 4, 5))
var decompressed = ArrayBuffer.empty[(Int, Int, Int, Float)]
var i = 0
while (i < compressed.srcIds.size) {
while (i < compressed.srcIds.length) {
var j = compressed.dstPtrs(i)
while (j < compressed.dstPtrs(i + 1)) {
val dstEncodedIndex = compressed.dstEncodedIndices(j)
@@ -483,4 +480,67 @@ class ALSSuite extends SparkFunSuite with MLlibTestSparkContext with Logging {
ALS.train(ratings, rank = 1, maxIter = 50, numUserBlocks = 2, numItemBlocks = 2,
implicitPrefs = true, seed = 0)
}
test("read/write") {
import ALSSuite._
val (ratings, _) = genExplicitTestData(numUsers = 4, numItems = 4, rank = 1)
val als = new ALS()
allEstimatorParamSettings.foreach { case (p, v) =>
als.set(als.getParam(p), v)
}
val sqlContext = this.sqlContext
import sqlContext.implicits._
val model = als.fit(ratings.toDF())
// Test Estimator save/load
val als2 = testDefaultReadWrite(als)
allEstimatorParamSettings.foreach { case (p, v) =>
val param = als.getParam(p)
assert(als.get(param).get === als2.get(param).get)
}
// Test Model save/load
val model2 = testDefaultReadWrite(model)
allModelParamSettings.foreach { case (p, v) =>
val param = model.getParam(p)
assert(model.get(param).get === model2.get(param).get)
}
assert(model.rank === model2.rank)
def getFactors(df: DataFrame): Set[(Int, Array[Float])] = {
df.select("id", "features").collect().map { case r =>
(r.getInt(0), r.getAs[Array[Float]](1))
}.toSet
}
assert(getFactors(model.userFactors) === getFactors(model2.userFactors))
assert(getFactors(model.itemFactors) === getFactors(model2.itemFactors))
}
}
object ALSSuite {
/**
* Mapping from all Params to valid settings which differ from the defaults.
* This is useful for tests which need to exercise all Params, such as save/load.
* This excludes input columns to simplify some tests.
*/
val allModelParamSettings: Map[String, Any] = Map(
"predictionCol" -> "myPredictionCol"
)
/**
* Mapping from all Params to valid settings which differ from the defaults.
* This is useful for tests which need to exercise all Params, such as save/load.
* This excludes input columns to simplify some tests.
*/
val allEstimatorParamSettings: Map[String, Any] = allModelParamSettings ++ Map(
"maxIter" -> 1,
"rank" -> 1,
"regParam" -> 0.01,
"numUserBlocks" -> 2,
"numItemBlocks" -> 2,
"implicitPrefs" -> true,
"alpha" -> 0.9,
"nonnegative" -> true,
"checkpointInterval" -> 20
)
}