Skip to content
Snippets Groups Projects
Commit 1f2f723b authored by Yanbo Liang, committed by Xiangrui Meng
Browse files

[SPARK-5990] [MLLIB] Model import/export for IsotonicRegression

Model import/export for IsotonicRegression

Author: Yanbo Liang <ybliang8@gmail.com>

Closes #5270 from yanboliang/spark-5990 and squashes the following commits:

872028d [Yanbo Liang] fix code style
f80ec1b [Yanbo Liang] address comments
49600cc [Yanbo Liang] address comments
429ff7d [Yanbo Liang] store each interval as a record
2b2f5a1 [Yanbo Liang] Model import/export for IsotonicRegression
parent ab9128fb
No related branches found
No related tags found
No related merge requests found
......@@ -23,9 +23,16 @@ import java.util.Arrays.binarySearch
import scala.collection.mutable.ArrayBuffer
import org.json4s._
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods._
import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.{JavaDoubleRDD, JavaRDD}
import org.apache.spark.mllib.util.{Loader, Saveable}
import org.apache.spark.rdd.RDD
import org.apache.spark.SparkContext
import org.apache.spark.sql.{DataFrame, SQLContext}
/**
* :: Experimental ::
......@@ -42,7 +49,7 @@ import org.apache.spark.rdd.RDD
class IsotonicRegressionModel (
val boundaries: Array[Double],
val predictions: Array[Double],
val isotonic: Boolean) extends Serializable {
val isotonic: Boolean) extends Serializable with Saveable {
private val predictionOrd = if (isotonic) Ordering[Double] else Ordering[Double].reverse
......@@ -124,6 +131,75 @@ class IsotonicRegressionModel (
predictions(foundIndex)
}
}
/**
 * Writes this model to `path` by delegating to the format version 1.0 writer
 * defined in the companion object.
 *
 * @param sc   Spark context handed to the writer.
 * @param path Location handed to the writer.
 */
override def save(sc: SparkContext, path: String): Unit = {
  val writer = IsotonicRegressionModel.SaveLoadV1_0
  writer.save(sc, path, boundaries, predictions, isotonic)
}
/** Format version identifier for this model; same value as `SaveLoadV1_0.thisFormatVersion`. */
override protected def formatVersion: String = {
  "1.0"
}
}
/**
 * Companion object providing model import for [[IsotonicRegressionModel]], plus the
 * private reader/writer for on-disk format version 1.0.
 */
object IsotonicRegressionModel extends Loader[IsotonicRegressionModel] {

  import org.apache.spark.mllib.util.Loader._

  /** Reader/writer for format version 1.0. */
  private object SaveLoadV1_0 {

    def thisFormatVersion: String = "1.0"

    /** Hard-code class name string in case it changes in the future */
    def thisClassName: String = "org.apache.spark.mllib.regression.IsotonicRegressionModel"

    /** Model data for model import/export: one record per (boundary, prediction) pair. */
    case class Data(boundary: Double, prediction: Double)

    /**
     * Writes the model as a one-line JSON metadata text file (class, version, isotonic)
     * and a Parquet file holding one [[Data]] record per boundary.
     *
     * @param sc          Spark context used to write both outputs.
     * @param path        Base location; the metadata and data locations are derived from it.
     * @param boundaries  Boundary values, paired positionally with `predictions`.
     * @param predictions Prediction values, paired positionally with `boundaries`.
     * @param isotonic    Flag stored in the metadata.
     * @throws IllegalArgumentException if the two arrays differ in length.
     */
    def save(
        sc: SparkContext,
        path: String,
        boundaries: Array[Double],
        predictions: Array[Double],
        isotonic: Boolean): Unit = {
      // zip would silently drop the unmatched tail, producing a truncated model on disk.
      require(boundaries.length == predictions.length,
        s"boundaries (${boundaries.length}) and predictions (${predictions.length}) " +
        "must have the same length")

      val sqlContext = new SQLContext(sc)

      val metadata = compact(render(
        ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~
          ("isotonic" -> isotonic)))
      sc.parallelize(Seq(metadata), 1).saveAsTextFile(metadataPath(path))

      sqlContext.createDataFrame(
        boundaries.toSeq.zip(predictions).map { case (b, p) => Data(b, p) }
      ).saveAsParquetFile(dataPath(path))
    }

    /**
     * Reads the Parquet records written by `save` and returns them as parallel arrays.
     *
     * @param sc   Spark context used to read the data.
     * @param path Base location that was passed to `save`.
     * @return (boundaries, predictions), ordered by ascending boundary.
     */
    def load(sc: SparkContext, path: String): (Array[Double], Array[Double]) = {
      val sqlContext = new SQLContext(sc)
      val dataRDD = sqlContext.parquetFile(dataPath(path))

      checkSchema[Data](dataRDD.schema)
      val dataArray = dataRDD.select("boundary", "prediction").collect()
      // Records are re-sorted by boundary here rather than relying on the order
      // in which they come back from storage.
      val sortedPairs = dataArray
        .map { row => (row.getDouble(0), row.getDouble(1)) }
        .sortBy(_._1)
      (sortedPairs.map(_._1), sortedPairs.map(_._2))
    }
  }

  /**
   * Loads a model previously written by `IsotonicRegressionModel.save`.
   *
   * @param sc   Spark context used to read the metadata and data.
   * @param path Base location that was passed to `save`.
   * @return The reconstructed model.
   * @throws Exception if the stored (class name, format version) pair is not supported.
   */
  override def load(sc: SparkContext, path: String): IsotonicRegressionModel = {
    implicit val formats: Formats = DefaultFormats
    val (loadedClassName, version, metadata) = loadMetadata(sc, path)
    val classNameV1_0 = SaveLoadV1_0.thisClassName
    (loadedClassName, version) match {
      case (className, "1.0") if className == classNameV1_0 =>
        // Read format-specific fields only once the format is recognized, so that
        // unrecognized metadata reaches the descriptive error below instead of
        // failing during JSON extraction.
        val isotonic = (metadata \ "isotonic").extract[Boolean]
        val (boundaries, predictions) = SaveLoadV1_0.load(sc, path)
        new IsotonicRegressionModel(boundaries, predictions, isotonic)
      case _ => throw new Exception(
        s"IsotonicRegressionModel.load did not recognize model with " +
        s"(className, format version): ($loadedClassName, $version). Supported:\n" +
        s" ($classNameV1_0, 1.0)"
      )
    }
  }
}
/**
......
......@@ -21,6 +21,7 @@ import org.scalatest.{Matchers, FunSuite}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.util.Utils
class IsotonicRegressionSuite extends FunSuite with MLlibTestSparkContext with Matchers {
......@@ -73,6 +74,26 @@ class IsotonicRegressionSuite extends FunSuite with MLlibTestSparkContext with M
assert(model.isotonic)
}
// Round-trip check: a model written with save() and read back with load() must
// carry the same boundaries, predictions and isotonic flag.
test("model save/load") {
  val boundaries = Array(0.0, 1.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0)
  val predictions = Array(1.0, 2.0, 2.0, 6.0, 16.5, 16.5, 17.0, 18.0)
  val model = new IsotonicRegressionModel(boundaries, predictions, true)

  val tempDir = Utils.createTempDir()
  val path = tempDir.toURI.toString

  // Save model, load it back, and compare.
  try {
    model.save(sc, path)
    val sameModel = IsotonicRegressionModel.load(sc, path)
    assert(model.boundaries === sameModel.boundaries)
    assert(model.predictions === sameModel.predictions)
    // Compare against the loaded model; comparing the model with itself always passes.
    assert(model.isotonic === sameModel.isotonic)
  } finally {
    // Remove the temporary directory even when an assertion fails.
    Utils.deleteRecursively(tempDir)
  }
}
test("isotonic regression with size 0") {
val model = runIsotonicRegression(Seq(), true)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment