diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala index ec98b05e13b89311f75366a19d298024c31b15ca..8361406f87299d6b2e307d2ff8bac511bf140392 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala @@ -24,6 +24,9 @@ import scala.annotation.varargs import scala.collection.mutable import scala.collection.JavaConverters._ +import org.json4s._ +import org.json4s.jackson.JsonMethods._ + import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.ml.util.Identifiable @@ -80,6 +83,30 @@ class Param[T](val parent: String, val name: String, val doc: String, val isVali /** Creates a param pair with the given value (for Scala). */ def ->(value: T): ParamPair[T] = ParamPair(this, value) + /** Encodes a param value into JSON, which can be decoded by [[jsonDecode()]]. */ + def jsonEncode(value: T): String = { + value match { + case x: String => + compact(render(JString(x))) + case _ => + throw new NotImplementedError( + "The default jsonEncode only supports string. " + + s"${this.getClass.getName} must override jsonEncode for ${value.getClass.getName}.") + } + } + + /** Decodes a param value from JSON. */ + def jsonDecode(json: String): T = { + parse(json) match { + case JString(x) => + x.asInstanceOf[T] + case _ => + throw new NotImplementedError( + "The default jsonDecode only supports string. " + + s"${this.getClass.getName} must override jsonDecode to support its value type.") + } + } + override final def toString: String = s"${parent}__$name" override final def hashCode: Int = toString.## @@ -198,6 +225,46 @@ class DoubleParam(parent: String, name: String, doc: String, isValid: Double => /** Creates a param pair with the given value (for Java). */ override def w(value: Double): ParamPair[Double] = super.w(value) + + override def jsonEncode(value: Double): String = { + compact(render(DoubleParam.jValueEncode(value))) + } + + override def jsonDecode(json: String): Double = { + DoubleParam.jValueDecode(parse(json)) + } +} + +private[param] object DoubleParam { + /** Encodes a param value into JValue. */ + def jValueEncode(value: Double): JValue = { + value match { + case _ if value.isNaN => + JString("NaN") + case Double.NegativeInfinity => + JString("-Inf") + case Double.PositiveInfinity => + JString("Inf") + case _ => + JDouble(value) + } + } + + /** Decodes a param value from JValue. */ + def jValueDecode(jValue: JValue): Double = { + jValue match { + case JString("NaN") => + Double.NaN + case JString("-Inf") => + Double.NegativeInfinity + case JString("Inf") => + Double.PositiveInfinity + case JDouble(x) => + x + case _ => + throw new IllegalArgumentException(s"Cannot decode $jValue to Double.") + } + } } /** @@ -218,6 +285,15 @@ class IntParam(parent: String, name: String, doc: String, isValid: Int => Boolea /** Creates a param pair with the given value (for Java). */ override def w(value: Int): ParamPair[Int] = super.w(value) + + override def jsonEncode(value: Int): String = { + compact(render(JInt(value))) + } + + override def jsonDecode(json: String): Int = { + implicit val formats = DefaultFormats + parse(json).extract[Int] + } } /** @@ -238,6 +314,47 @@ class FloatParam(parent: String, name: String, doc: String, isValid: Float => Bo /** Creates a param pair with the given value (for Java). */ override def w(value: Float): ParamPair[Float] = super.w(value) + + override def jsonEncode(value: Float): String = { + compact(render(FloatParam.jValueEncode(value))) + } + + override def jsonDecode(json: String): Float = { + FloatParam.jValueDecode(parse(json)) + } +} + +private object FloatParam { + + /** Encodes a param value into JValue. */ + def jValueEncode(value: Float): JValue = { + value match { + case _ if value.isNaN => + JString("NaN") + case Float.NegativeInfinity => + JString("-Inf") + case Float.PositiveInfinity => + JString("Inf") + case _ => + JDouble(value) + } + } + + /** Decodes a param value from JValue. */ + def jValueDecode(jValue: JValue): Float = { + jValue match { + case JString("NaN") => + Float.NaN + case JString("-Inf") => + Float.NegativeInfinity + case JString("Inf") => + Float.PositiveInfinity + case JDouble(x) => + x.toFloat + case _ => + throw new IllegalArgumentException(s"Cannot decode $jValue to Float.") + } + } } /** @@ -258,6 +375,15 @@ class LongParam(parent: String, name: String, doc: String, isValid: Long => Bool /** Creates a param pair with the given value (for Java). */ override def w(value: Long): ParamPair[Long] = super.w(value) + + override def jsonEncode(value: Long): String = { + compact(render(JInt(value))) + } + + override def jsonDecode(json: String): Long = { + implicit val formats = DefaultFormats + parse(json).extract[Long] + } } /** @@ -272,6 +398,15 @@ class BooleanParam(parent: String, name: String, doc: String) // No need for isV /** Creates a param pair with the given value (for Java). */ override def w(value: Boolean): ParamPair[Boolean] = super.w(value) + + override def jsonEncode(value: Boolean): String = { + compact(render(JBool(value))) + } + + override def jsonDecode(json: String): Boolean = { + implicit val formats = DefaultFormats + parse(json).extract[Boolean] + } } /** @@ -287,6 +422,16 @@ class StringArrayParam(parent: Params, name: String, doc: String, isValid: Array /** Creates a param pair with a [[java.util.List]] of values (for Java and Python). */ def w(value: java.util.List[String]): ParamPair[Array[String]] = w(value.asScala.toArray) + + override def jsonEncode(value: Array[String]): String = { + import org.json4s.JsonDSL._ + compact(render(value.toSeq)) + } + + override def jsonDecode(json: String): Array[String] = { + implicit val formats = DefaultFormats + parse(json).extract[Seq[String]].toArray + } } /** @@ -303,6 +448,20 @@ class DoubleArrayParam(parent: Params, name: String, doc: String, isValid: Array /** Creates a param pair with a [[java.util.List]] of values (for Java and Python). */ def w(value: java.util.List[java.lang.Double]): ParamPair[Array[Double]] = w(value.asScala.map(_.asInstanceOf[Double]).toArray) + + override def jsonEncode(value: Array[Double]): String = { + import org.json4s.JsonDSL._ + compact(render(value.toSeq.map(DoubleParam.jValueEncode))) + } + + override def jsonDecode(json: String): Array[Double] = { + parse(json) match { + case JArray(values) => + values.map(DoubleParam.jValueDecode).toArray + case _ => + throw new IllegalArgumentException(s"Cannot decode $json to Array[Double].") + } + } } /** @@ -319,6 +478,16 @@ class IntArrayParam(parent: Params, name: String, doc: String, isValid: Array[In /** Creates a param pair with a [[java.util.List]] of values (for Java and Python). */ def w(value: java.util.List[java.lang.Integer]): ParamPair[Array[Int]] = w(value.asScala.map(_.asInstanceOf[Int]).toArray) + + override def jsonEncode(value: Array[Int]): String = { + import org.json4s.JsonDSL._ + compact(render(value.toSeq)) + } + + override def jsonDecode(json: String): Array[Int] = { + implicit val formats = DefaultFormats + parse(json).extract[Seq[Int]].toArray + } } /** diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala index a2ea279f5d5e4f7166c394a8e8c8dae85dd211b7..eeb03dba2f8259fd9237b400af7404a48c5c64a7 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala @@ -21,6 +21,120 @@ import org.apache.spark.SparkFunSuite class ParamsSuite extends SparkFunSuite { + test("json encode/decode") { + val dummy = new Params { + override def copy(extra: ParamMap): Params = defaultCopy(extra) + + override val uid: String = "dummy" + } + + { // BooleanParam + val param = new BooleanParam(dummy, "name", "doc") + for (value <- Seq(true, false)) { + val json = param.jsonEncode(value) + assert(param.jsonDecode(json) === value) + } + } + + { // IntParam + val param = new IntParam(dummy, "name", "doc") + for (value <- Seq(Int.MinValue, -1, 0, 1, Int.MaxValue)) { + val json = param.jsonEncode(value) + assert(param.jsonDecode(json) === value) + } + } + + { // LongParam + val param = new LongParam(dummy, "name", "doc") + for (value <- Seq(Long.MinValue, -1L, 0L, 1L, Long.MaxValue)) { + val json = param.jsonEncode(value) + assert(param.jsonDecode(json) === value) + } + } + + { // FloatParam + val param = new FloatParam(dummy, "name", "doc") + for (value <- Seq(Float.NaN, Float.NegativeInfinity, Float.MinValue, -1.0f, -0.5f, 0.0f, + Float.MinPositiveValue, 0.5f, 1.0f, Float.MaxValue, Float.PositiveInfinity)) { + val json = param.jsonEncode(value) + val decoded = param.jsonDecode(json) + if (value.isNaN) { + assert(decoded.isNaN) + } else { + assert(decoded === value) + } + } + } + + { // DoubleParam + val param = new DoubleParam(dummy, "name", "doc") + for (value <- Seq(Double.NaN, Double.NegativeInfinity, Double.MinValue, -1.0, -0.5, 0.0, + Double.MinPositiveValue, 0.5, 1.0, Double.MaxValue, Double.PositiveInfinity)) { + val json = param.jsonEncode(value) + val decoded = param.jsonDecode(json) + if (value.isNaN) { + assert(decoded.isNaN) + } else { + assert(decoded === value) + } + } + } + + { // StringParam + val param = new Param[String](dummy, "name", "doc") + // Currently we do not support null. + for (value <- Seq("", "1", "abc", "quote\"", "newline\n")) { + val json = param.jsonEncode(value) + assert(param.jsonDecode(json) === value) + } + } + + { // IntArrayParam + val param = new IntArrayParam(dummy, "name", "doc") + val values: Seq[Array[Int]] = Seq( + Array(), + Array(1), + Array(Int.MinValue, 0, Int.MaxValue)) + for (value <- values) { + val json = param.jsonEncode(value) + assert(param.jsonDecode(json) === value) + } + } + + { // DoubleArrayParam + val param = new DoubleArrayParam(dummy, "name", "doc") + val values: Seq[Array[Double]] = Seq( + Array(), + Array(1.0), + Array(Double.NaN, Double.NegativeInfinity, Double.MinValue, -1.0, 0.0, + Double.MinPositiveValue, 1.0, Double.MaxValue, Double.PositiveInfinity)) + for (value <- values) { + val json = param.jsonEncode(value) + val decoded = param.jsonDecode(json) + assert(decoded.length === value.length) + decoded.zip(value).foreach { case (actual, expected) => + if (expected.isNaN) { + assert(actual.isNaN) + } else { + assert(actual === expected) + } + } + } + } + + { // StringArrayParam + val param = new StringArrayParam(dummy, "name", "doc") + val values: Seq[Array[String]] = Seq( + Array(), + Array(""), + Array("", "1", "abc", "quote\"", "newline\n")) + for (value <- values) { + val json = param.jsonEncode(value) + assert(param.jsonDecode(json) === value) + } + } + } + test("param") { val solver = new TestParams() val uid = solver.uid