Skip to content
Snippets Groups Projects
Commit 2b574f52 authored by Xiangrui Meng's avatar Xiangrui Meng
Browse files

[SPARK-7402] [ML] JSON SerDe for standard param types

This PR implements the JSON SerDe for the following param types: `Boolean`, `Int`, `Long`, `Float`, `Double`, `String`, `Array[Int]`, `Array[Double]`, and `Array[String]`. The implementation of `Float`, `Double`, and `Array[Double]` are specialized to handle `NaN` and `Inf`s. This will be used in pipeline persistence. jkbradley

Author: Xiangrui Meng <meng@databricks.com>

Closes #9090 from mengxr/SPARK-7402.
parent c75f058b
No related branches found
No related tags found
No related merge requests found
...@@ -24,6 +24,9 @@ import scala.annotation.varargs ...@@ -24,6 +24,9 @@ import scala.annotation.varargs
import scala.collection.mutable import scala.collection.mutable
import scala.collection.JavaConverters._ import scala.collection.JavaConverters._
import org.json4s._
import org.json4s.jackson.JsonMethods._
import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.annotation.{DeveloperApi, Experimental}
import org.apache.spark.ml.util.Identifiable import org.apache.spark.ml.util.Identifiable
...@@ -80,6 +83,30 @@ class Param[T](val parent: String, val name: String, val doc: String, val isVali ...@@ -80,6 +83,30 @@ class Param[T](val parent: String, val name: String, val doc: String, val isVali
/** Creates a param pair with the given value (for Scala). */ /** Creates a param pair with the given value (for Scala). */
def ->(value: T): ParamPair[T] = ParamPair(this, value) def ->(value: T): ParamPair[T] = ParamPair(this, value)
/** Encodes a param value into JSON, which can be decoded by [[jsonDecode()]]. */
def jsonEncode(value: T): String = {
value match {
case x: String =>
compact(render(JString(x)))
case _ =>
throw new NotImplementedError(
"The default jsonEncode only supports string. " +
s"${this.getClass.getName} must override jsonEncode for ${value.getClass.getName}.")
}
}
/** Decodes a param value from JSON. */
def jsonDecode(json: String): T = {
parse(json) match {
case JString(x) =>
x.asInstanceOf[T]
case _ =>
throw new NotImplementedError(
"The default jsonDecode only supports string. " +
s"${this.getClass.getName} must override jsonDecode to support its value type.")
}
}
override final def toString: String = s"${parent}__$name" override final def toString: String = s"${parent}__$name"
override final def hashCode: Int = toString.## override final def hashCode: Int = toString.##
...@@ -198,6 +225,46 @@ class DoubleParam(parent: String, name: String, doc: String, isValid: Double => ...@@ -198,6 +225,46 @@ class DoubleParam(parent: String, name: String, doc: String, isValid: Double =>
/** Creates a param pair with the given value (for Java). */ /** Creates a param pair with the given value (for Java). */
override def w(value: Double): ParamPair[Double] = super.w(value) override def w(value: Double): ParamPair[Double] = super.w(value)
override def jsonEncode(value: Double): String = {
compact(render(DoubleParam.jValueEncode(value)))
}
override def jsonDecode(json: String): Double = {
DoubleParam.jValueDecode(parse(json))
}
}
private[param] object DoubleParam {
/** Encodes a param value into JValue. */
def jValueEncode(value: Double): JValue = {
value match {
case _ if value.isNaN =>
JString("NaN")
case Double.NegativeInfinity =>
JString("-Inf")
case Double.PositiveInfinity =>
JString("Inf")
case _ =>
JDouble(value)
}
}
/** Decodes a param value from JValue. */
def jValueDecode(jValue: JValue): Double = {
jValue match {
case JString("NaN") =>
Double.NaN
case JString("-Inf") =>
Double.NegativeInfinity
case JString("Inf") =>
Double.PositiveInfinity
case JDouble(x) =>
x
case _ =>
throw new IllegalArgumentException(s"Cannot decode $jValue to Double.")
}
}
} }
/** /**
...@@ -218,6 +285,15 @@ class IntParam(parent: String, name: String, doc: String, isValid: Int => Boolea ...@@ -218,6 +285,15 @@ class IntParam(parent: String, name: String, doc: String, isValid: Int => Boolea
/** Creates a param pair with the given value (for Java). */ /** Creates a param pair with the given value (for Java). */
override def w(value: Int): ParamPair[Int] = super.w(value) override def w(value: Int): ParamPair[Int] = super.w(value)
override def jsonEncode(value: Int): String = {
compact(render(JInt(value)))
}
override def jsonDecode(json: String): Int = {
implicit val formats = DefaultFormats
parse(json).extract[Int]
}
} }
/** /**
...@@ -238,6 +314,47 @@ class FloatParam(parent: String, name: String, doc: String, isValid: Float => Bo ...@@ -238,6 +314,47 @@ class FloatParam(parent: String, name: String, doc: String, isValid: Float => Bo
/** Creates a param pair with the given value (for Java). */ /** Creates a param pair with the given value (for Java). */
override def w(value: Float): ParamPair[Float] = super.w(value) override def w(value: Float): ParamPair[Float] = super.w(value)
override def jsonEncode(value: Float): String = {
compact(render(FloatParam.jValueEncode(value)))
}
override def jsonDecode(json: String): Float = {
FloatParam.jValueDecode(parse(json))
}
}
private object FloatParam {
/** Encodes a param value into JValue. */
def jValueEncode(value: Float): JValue = {
value match {
case _ if value.isNaN =>
JString("NaN")
case Float.NegativeInfinity =>
JString("-Inf")
case Float.PositiveInfinity =>
JString("Inf")
case _ =>
JDouble(value)
}
}
/** Decodes a param value from JValue. */
def jValueDecode(jValue: JValue): Float = {
jValue match {
case JString("NaN") =>
Float.NaN
case JString("-Inf") =>
Float.NegativeInfinity
case JString("Inf") =>
Float.PositiveInfinity
case JDouble(x) =>
x.toFloat
case _ =>
throw new IllegalArgumentException(s"Cannot decode $jValue to Float.")
}
}
} }
/** /**
...@@ -258,6 +375,15 @@ class LongParam(parent: String, name: String, doc: String, isValid: Long => Bool ...@@ -258,6 +375,15 @@ class LongParam(parent: String, name: String, doc: String, isValid: Long => Bool
/** Creates a param pair with the given value (for Java). */ /** Creates a param pair with the given value (for Java). */
override def w(value: Long): ParamPair[Long] = super.w(value) override def w(value: Long): ParamPair[Long] = super.w(value)
override def jsonEncode(value: Long): String = {
compact(render(JInt(value)))
}
override def jsonDecode(json: String): Long = {
implicit val formats = DefaultFormats
parse(json).extract[Long]
}
} }
/** /**
...@@ -272,6 +398,15 @@ class BooleanParam(parent: String, name: String, doc: String) // No need for isV ...@@ -272,6 +398,15 @@ class BooleanParam(parent: String, name: String, doc: String) // No need for isV
/** Creates a param pair with the given value (for Java). */ /** Creates a param pair with the given value (for Java). */
override def w(value: Boolean): ParamPair[Boolean] = super.w(value) override def w(value: Boolean): ParamPair[Boolean] = super.w(value)
override def jsonEncode(value: Boolean): String = {
compact(render(JBool(value)))
}
override def jsonDecode(json: String): Boolean = {
implicit val formats = DefaultFormats
parse(json).extract[Boolean]
}
} }
/** /**
...@@ -287,6 +422,16 @@ class StringArrayParam(parent: Params, name: String, doc: String, isValid: Array ...@@ -287,6 +422,16 @@ class StringArrayParam(parent: Params, name: String, doc: String, isValid: Array
/** Creates a param pair with a [[java.util.List]] of values (for Java and Python). */ /** Creates a param pair with a [[java.util.List]] of values (for Java and Python). */
def w(value: java.util.List[String]): ParamPair[Array[String]] = w(value.asScala.toArray) def w(value: java.util.List[String]): ParamPair[Array[String]] = w(value.asScala.toArray)
override def jsonEncode(value: Array[String]): String = {
import org.json4s.JsonDSL._
compact(render(value.toSeq))
}
override def jsonDecode(json: String): Array[String] = {
implicit val formats = DefaultFormats
parse(json).extract[Seq[String]].toArray
}
} }
/** /**
...@@ -303,6 +448,20 @@ class DoubleArrayParam(parent: Params, name: String, doc: String, isValid: Array ...@@ -303,6 +448,20 @@ class DoubleArrayParam(parent: Params, name: String, doc: String, isValid: Array
/** Creates a param pair with a [[java.util.List]] of values (for Java and Python). */ /** Creates a param pair with a [[java.util.List]] of values (for Java and Python). */
def w(value: java.util.List[java.lang.Double]): ParamPair[Array[Double]] = def w(value: java.util.List[java.lang.Double]): ParamPair[Array[Double]] =
w(value.asScala.map(_.asInstanceOf[Double]).toArray) w(value.asScala.map(_.asInstanceOf[Double]).toArray)
override def jsonEncode(value: Array[Double]): String = {
import org.json4s.JsonDSL._
compact(render(value.toSeq.map(DoubleParam.jValueEncode)))
}
override def jsonDecode(json: String): Array[Double] = {
parse(json) match {
case JArray(values) =>
values.map(DoubleParam.jValueDecode).toArray
case _ =>
throw new IllegalArgumentException(s"Cannot decode $json to Array[Double].")
}
}
} }
/** /**
...@@ -319,6 +478,16 @@ class IntArrayParam(parent: Params, name: String, doc: String, isValid: Array[In ...@@ -319,6 +478,16 @@ class IntArrayParam(parent: Params, name: String, doc: String, isValid: Array[In
/** Creates a param pair with a [[java.util.List]] of values (for Java and Python). */ /** Creates a param pair with a [[java.util.List]] of values (for Java and Python). */
def w(value: java.util.List[java.lang.Integer]): ParamPair[Array[Int]] = def w(value: java.util.List[java.lang.Integer]): ParamPair[Array[Int]] =
w(value.asScala.map(_.asInstanceOf[Int]).toArray) w(value.asScala.map(_.asInstanceOf[Int]).toArray)
override def jsonEncode(value: Array[Int]): String = {
import org.json4s.JsonDSL._
compact(render(value.toSeq))
}
override def jsonDecode(json: String): Array[Int] = {
implicit val formats = DefaultFormats
parse(json).extract[Seq[Int]].toArray
}
} }
/** /**
......
...@@ -21,6 +21,120 @@ import org.apache.spark.SparkFunSuite ...@@ -21,6 +21,120 @@ import org.apache.spark.SparkFunSuite
class ParamsSuite extends SparkFunSuite { class ParamsSuite extends SparkFunSuite {
test("json encode/decode") {
val dummy = new Params {
override def copy(extra: ParamMap): Params = defaultCopy(extra)
override val uid: String = "dummy"
}
{ // BooleanParam
val param = new BooleanParam(dummy, "name", "doc")
for (value <- Seq(true, false)) {
val json = param.jsonEncode(value)
assert(param.jsonDecode(json) === value)
}
}
{ // IntParam
val param = new IntParam(dummy, "name", "doc")
for (value <- Seq(Int.MinValue, -1, 0, 1, Int.MaxValue)) {
val json = param.jsonEncode(value)
assert(param.jsonDecode(json) === value)
}
}
{ // LongParam
val param = new LongParam(dummy, "name", "doc")
for (value <- Seq(Long.MinValue, -1L, 0L, 1L, Long.MaxValue)) {
val json = param.jsonEncode(value)
assert(param.jsonDecode(json) === value)
}
}
{ // FloatParam
val param = new FloatParam(dummy, "name", "doc")
for (value <- Seq(Float.NaN, Float.NegativeInfinity, Float.MinValue, -1.0f, -0.5f, 0.0f,
Float.MinPositiveValue, 0.5f, 1.0f, Float.MaxValue, Float.PositiveInfinity)) {
val json = param.jsonEncode(value)
val decoded = param.jsonDecode(json)
if (value.isNaN) {
assert(decoded.isNaN)
} else {
assert(decoded === value)
}
}
}
{ // DoubleParam
val param = new DoubleParam(dummy, "name", "doc")
for (value <- Seq(Double.NaN, Double.NegativeInfinity, Double.MinValue, -1.0, -0.5, 0.0,
Double.MinPositiveValue, 0.5, 1.0, Double.MaxValue, Double.PositiveInfinity)) {
val json = param.jsonEncode(value)
val decoded = param.jsonDecode(json)
if (value.isNaN) {
assert(decoded.isNaN)
} else {
assert(decoded === value)
}
}
}
{ // StringParam
val param = new Param[String](dummy, "name", "doc")
// Currently we do not support null.
for (value <- Seq("", "1", "abc", "quote\"", "newline\n")) {
val json = param.jsonEncode(value)
assert(param.jsonDecode(json) === value)
}
}
{ // IntArrayParam
val param = new IntArrayParam(dummy, "name", "doc")
val values: Seq[Array[Int]] = Seq(
Array(),
Array(1),
Array(Int.MinValue, 0, Int.MaxValue))
for (value <- values) {
val json = param.jsonEncode(value)
assert(param.jsonDecode(json) === value)
}
}
{ // DoubleArrayParam
val param = new DoubleArrayParam(dummy, "name", "doc")
val values: Seq[Array[Double]] = Seq(
Array(),
Array(1.0),
Array(Double.NaN, Double.NegativeInfinity, Double.MinValue, -1.0, 0.0,
Double.MinPositiveValue, 1.0, Double.MaxValue, Double.PositiveInfinity))
for (value <- values) {
val json = param.jsonEncode(value)
val decoded = param.jsonDecode(json)
assert(decoded.length === value.length)
decoded.zip(value).foreach { case (actual, expected) =>
if (expected.isNaN) {
assert(actual.isNaN)
} else {
assert(actual === expected)
}
}
}
}
{ // StringArrayParam
val param = new StringArrayParam(dummy, "name", "doc")
val values: Seq[Array[String]] = Seq(
Array(),
Array(""),
Array("", "1", "abc", "quote\"", "newline\n"))
for (value <- values) {
val json = param.jsonEncode(value)
assert(param.jsonDecode(json) === value)
}
}
}
test("param") { test("param") {
val solver = new TestParams() val solver = new TestParams()
val uid = solver.uid val uid = solver.uid
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment