Commit c447c9d5 authored by Xiangrui Meng, committed by Joseph K. Bradley

[SPARK-11217][ML] save/load for non-meta estimators and transformers

This PR implements the default save/load for non-meta estimators and transformers using the JSON serialization of param values. The saved metadata includes:

* class name
* uid
* timestamp
* paramMap
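
For illustration, the saved metadata JSON might look like the following for a `Binarizer` instance (all field values here are hypothetical):

~~~json
{
  "class": "org.apache.spark.ml.feature.Binarizer",
  "timestamp": 1446768000000,
  "uid": "binarizer_5cd1c8e85cbc",
  "paramMap": {
    "threshold": 0.1,
    "inputCol": "feature",
    "outputCol": "binarized_feature"
  }
}
~~~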

The save/load interface is similar to that of DataFrames. We use the currently active SQL context by default, which should be sufficient for most use cases.

~~~scala
instance.save("path")
instance.write.context(sqlContext).overwrite().save("path")

Instance.load("path")
~~~

The param handling differs from the design doc: we didn't save default and user-set params separately, so when an instance is loaded back, all of its params are user-set. This does cause issues, but saving them separately would cause other issues whenever we modify the default params.
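
To make the consequence concrete, here is a hedged sketch using the `Binarizer` updated in this PR (the path is hypothetical):

~~~scala
val b = new Binarizer().setInputCol("feature").setOutputCol("out")
b.isSet(b.threshold)            // false: threshold only carries its default value
b.save("/tmp/binarizer")        // extractParamMap() folds defaults into the saved paramMap

val loaded = Binarizer.load("/tmp/binarizer")
loaded.isSet(loaded.threshold)  // true: every loaded param comes back as user-set
~~~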

TODOs:

* [x] Java test
* [ ] a follow-up PR to implement default save/load for all non-meta estimators and transformers

cc jkbradley

Author: Xiangrui Meng <meng@databricks.com>

Closes #9454 from mengxr/SPARK-11217.
parent 3a652f69
Binarizer.scala:

~~~diff
@@ -22,7 +22,7 @@ import org.apache.spark.ml.Transformer
 import org.apache.spark.ml.attribute.BinaryAttribute
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
-import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
+import org.apache.spark.ml.util._
 import org.apache.spark.sql._
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types.{DoubleType, StructType}
@@ -33,7 +33,7 @@ import org.apache.spark.sql.types.{DoubleType, StructType}
  */
 @Experimental
 final class Binarizer(override val uid: String)
-  extends Transformer with HasInputCol with HasOutputCol {
+  extends Transformer with Writable with HasInputCol with HasOutputCol {

   def this() = this(Identifiable.randomUID("binarizer"))
@@ -86,4 +86,11 @@ final class Binarizer(override val uid: String)
   }

   override def copy(extra: ParamMap): Binarizer = defaultCopy(extra)
+
+  override def write: Writer = new DefaultParamsWriter(this)
+}
+
+object Binarizer extends Readable[Binarizer] {
+
+  override def read: Reader[Binarizer] = new DefaultParamsReader[Binarizer]
 }
~~~
params.scala (`set` is made public, previously `protected`, so that `DefaultParamsReader` can populate params on a freshly constructed instance):

~~~diff
@@ -592,7 +592,7 @@ trait Params extends Identifiable with Serializable {

   /**
    * Sets a parameter in the embedded param map.
    */
-  protected final def set[T](param: Param[T], value: T): this.type = {
+  final def set[T](param: Param[T], value: T): this.type = {
     set(param -> value)
   }
~~~
ReadWrite.scala (new file):

~~~scala
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.ml.util

import java.io.IOException

import org.apache.hadoop.fs.{FileSystem, Path}
import org.json4s._
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods._

import org.apache.spark.{Logging, SparkContext}
import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml.param.{ParamPair, Params}
import org.apache.spark.sql.SQLContext
import org.apache.spark.util.Utils

/**
 * Trait for [[Writer]] and [[Reader]].
 */
private[util] sealed trait BaseReadWrite {

  private var optionSQLContext: Option[SQLContext] = None

  /**
   * Sets the SQL context to use for saving/loading.
   */
  @Since("1.6.0")
  def context(sqlContext: SQLContext): this.type = {
    optionSQLContext = Option(sqlContext)
    this
  }

  /**
   * Returns the user-specified SQL context or the default.
   */
  protected final def sqlContext: SQLContext = optionSQLContext.getOrElse {
    SQLContext.getOrCreate(SparkContext.getOrCreate())
  }
}

/**
 * Abstract class for utility classes that can save ML instances.
 */
@Experimental
@Since("1.6.0")
abstract class Writer extends BaseReadWrite {

  protected var shouldOverwrite: Boolean = false

  /**
   * Saves the ML instance to the given path.
   */
  @Since("1.6.0")
  @throws[IOException]("If the output path already exists but overwrite is not enabled.")
  def save(path: String): Unit

  /**
   * Overwrites if the output path already exists.
   */
  @Since("1.6.0")
  def overwrite(): this.type = {
    shouldOverwrite = true
    this
  }

  // override for Java compatibility
  override def context(sqlContext: SQLContext): this.type = super.context(sqlContext)
}

/**
 * Trait for classes that provide [[Writer]].
 */
@Since("1.6.0")
trait Writable {

  /**
   * Returns a [[Writer]] instance for this ML instance.
   */
  @Since("1.6.0")
  def write: Writer

  /**
   * Saves this ML instance to the given path, a shortcut for `write.save(path)`.
   */
  @Since("1.6.0")
  @throws[IOException]("If the output path already exists but overwrite is not enabled.")
  def save(path: String): Unit = write.save(path)
}

/**
 * Abstract class for utility classes that can load ML instances.
 * @tparam T ML instance type
 */
@Experimental
@Since("1.6.0")
abstract class Reader[T] extends BaseReadWrite {

  /**
   * Loads the ML component from the input path.
   */
  @Since("1.6.0")
  def load(path: String): T

  // override for Java compatibility
  override def context(sqlContext: SQLContext): this.type = super.context(sqlContext)
}

/**
 * Trait for objects that provide [[Reader]].
 * @tparam T ML instance type
 */
@Experimental
@Since("1.6.0")
trait Readable[T] {

  /**
   * Returns a [[Reader]] instance for this class.
   */
  @Since("1.6.0")
  def read: Reader[T]

  /**
   * Reads an ML instance from the input path, a shortcut for `read.load(path)`.
   */
  @Since("1.6.0")
  def load(path: String): T = read.load(path)
}

/**
 * Default [[Writer]] implementation for transformers and estimators that contain basic
 * (json4s-serializable) params and no data. This will not handle more complex params or types with
 * data (e.g., models with coefficients).
 * @param instance object to save
 */
private[ml] class DefaultParamsWriter(instance: Params) extends Writer with Logging {

  /**
   * Saves the ML component to the given path.
   */
  override def save(path: String): Unit = {
    val sc = sqlContext.sparkContext
    val hadoopConf = sc.hadoopConfiguration
    val fs = FileSystem.get(hadoopConf)
    val p = new Path(path)
    if (fs.exists(p)) {
      if (shouldOverwrite) {
        logInfo(s"Path $path already exists. It will be overwritten.")
        // TODO: Revert back to the original content if save is not successful.
        fs.delete(p, true)
      } else {
        throw new IOException(
          s"Path $path already exists. Please use write.overwrite().save(path) to overwrite it.")
      }
    }
    val uid = instance.uid
    val cls = instance.getClass.getName
    val params = instance.extractParamMap().toSeq.asInstanceOf[Seq[ParamPair[Any]]]
    val jsonParams = params.map { case ParamPair(p, v) =>
      p.name -> parse(p.jsonEncode(v))
    }.toList
    val metadata = ("class" -> cls) ~
      ("timestamp" -> System.currentTimeMillis()) ~
      ("uid" -> uid) ~
      ("paramMap" -> jsonParams)
    val metadataPath = new Path(path, "metadata").toString
    val metadataJson = compact(render(metadata))
    sc.parallelize(Seq(metadataJson), 1).saveAsTextFile(metadataPath)
  }
}

/**
 * Default [[Reader]] implementation for transformers and estimators that contain basic
 * (json4s-serializable) params and no data. This will not handle more complex params or types with
 * data (e.g., models with coefficients).
 * @tparam T ML instance type
 */
private[ml] class DefaultParamsReader[T] extends Reader[T] {

  /**
   * Loads the ML component from the input path.
   */
  override def load(path: String): T = {
    implicit val format = DefaultFormats
    val sc = sqlContext.sparkContext
    val metadataPath = new Path(path, "metadata").toString
    val metadataStr = sc.textFile(metadataPath, 1).first()
    val metadata = parse(metadataStr)
    val cls = Utils.classForName((metadata \ "class").extract[String])
    val uid = (metadata \ "uid").extract[String]
    val instance = cls.getConstructor(classOf[String]).newInstance(uid).asInstanceOf[Params]
    (metadata \ "paramMap") match {
      case JObject(pairs) =>
        pairs.foreach { case (paramName, jsonValue) =>
          val param = instance.getParam(paramName)
          val value = param.jsonDecode(compact(render(jsonValue)))
          instance.set(param, value)
        }
      case _ =>
        throw new IllegalArgumentException(s"Cannot recognize JSON metadata: $metadataStr.")
    }
    instance.asInstanceOf[T]
  }
}
~~~
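
Putting the pieces together, a Params class opts into the default machinery roughly as follows. This is a minimal sketch with hypothetical names, mirroring `MyParams` in the test suite below; note that `DefaultParamsWriter`/`DefaultParamsReader` are `private[ml]`, so such a class must live under `org.apache.spark.ml`:

~~~scala
package org.apache.spark.ml.example // hypothetical; must be under org.apache.spark.ml

import org.apache.spark.ml.param.{IntParam, ParamMap, Params}
import org.apache.spark.ml.util._

class MyThing(override val uid: String) extends Params with Writable {
  final val k: IntParam = new IntParam(this, "k", "an example int param")
  override def copy(extra: ParamMap): Params = defaultCopy(extra)
  // Persist JSON-serializable params with the default writer.
  override def write: Writer = new DefaultParamsWriter(this)
}

object MyThing extends Readable[MyThing] {
  // DefaultParamsReader reflectively invokes the single-String-argument
  // constructor, so MyThing(uid: String) must exist.
  override def read: Reader[MyThing] = new DefaultParamsReader[MyThing]
}
~~~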
JavaDefaultReadWriteSuite.java (new file):

~~~java
package org.apache.spark.ml.util;

import java.io.File;
import java.io.IOException;

import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;

import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.util.Utils;

public class JavaDefaultReadWriteSuite {

  JavaSparkContext jsc = null;
  File tempDir = null;

  @Before
  public void setUp() {
    jsc = new JavaSparkContext("local[2]", "JavaDefaultReadWriteSuite");
    tempDir = Utils.createTempDir(
      System.getProperty("java.io.tmpdir"), "JavaDefaultReadWriteSuite");
  }

  @After
  public void tearDown() {
    if (jsc != null) {
      jsc.stop();
      jsc = null;
    }
    Utils.deleteRecursively(tempDir);
  }

  @Test
  public void testDefaultReadWrite() throws IOException {
    String uid = "my_params";
    MyParams instance = new MyParams(uid);
    instance.set(instance.intParam(), 2);
    String outputPath = new File(tempDir, uid).getPath();
    instance.save(outputPath);
    try {
      instance.save(outputPath);
      Assert.fail(
        "Write without overwrite enabled should fail if the output directory already exists.");
    } catch (IOException e) {
      // expected
    }
    SQLContext sqlContext = new SQLContext(jsc);
    instance.write().context(sqlContext).overwrite().save(outputPath);
    MyParams newInstance = MyParams.load(outputPath);
    Assert.assertEquals("UID should match.", instance.uid(), newInstance.uid());
    Assert.assertEquals("Params should be preserved.",
      2, newInstance.getOrDefault(newInstance.intParam()));
  }
}
~~~
BinarizerSuite.scala:

~~~diff
@@ -19,10 +19,11 @@ package org.apache.spark.ml.feature

 import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.param.ParamsSuite
+import org.apache.spark.ml.util.DefaultReadWriteTest
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.sql.{DataFrame, Row}

-class BinarizerSuite extends SparkFunSuite with MLlibTestSparkContext {
+class BinarizerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {

   @transient var data: Array[Double] = _
@@ -66,4 +67,12 @@ class BinarizerSuite extends SparkFunSuite with MLlibTestSparkContext {
     assert(x === y, "The feature value is not correct after binarization.")
   }
+
+  test("read/write") {
+    val binarizer = new Binarizer()
+      .setInputCol("feature")
+      .setOutputCol("binarized_feature")
+      .setThreshold(0.1)
+    testDefaultReadWrite(binarizer)
+  }
 }
~~~
DefaultReadWriteTest.scala (new file):

~~~scala
package org.apache.spark.ml.util

import java.io.{File, IOException}

import org.scalatest.Suite

import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.param._
import org.apache.spark.mllib.util.MLlibTestSparkContext

trait DefaultReadWriteTest extends TempDirectory { self: Suite =>

  /**
   * Checks "overwrite" option and params.
   * @param instance ML instance to test saving/loading
   * @tparam T ML instance type
   */
  def testDefaultReadWrite[T <: Params with Writable](instance: T): Unit = {
    val uid = instance.uid
    val path = new File(tempDir, uid).getPath

    instance.save(path)
    intercept[IOException] {
      instance.save(path)
    }
    instance.write.overwrite().save(path)

    val loader = instance.getClass.getMethod("read").invoke(null).asInstanceOf[Reader[T]]
    val newInstance = loader.load(path)

    assert(newInstance.uid === instance.uid)
    instance.params.foreach { p =>
      if (instance.isDefined(p)) {
        (instance.getOrDefault(p), newInstance.getOrDefault(p)) match {
          case (Array(values), Array(newValues)) =>
            assert(values === newValues, s"Values do not match on param ${p.name}.")
          case (value, newValue) =>
            assert(value === newValue, s"Values do not match on param ${p.name}.")
        }
      } else {
        assert(!newInstance.isDefined(p), s"Param ${p.name} shouldn't be defined.")
      }
    }

    val load = instance.getClass.getMethod("load", classOf[String])
    val another = load.invoke(instance, path).asInstanceOf[T]
    assert(another.uid === instance.uid)
  }
}

class MyParams(override val uid: String) extends Params with Writable {

  final val intParamWithDefault: IntParam = new IntParam(this, "intParamWithDefault", "doc")
  final val intParam: IntParam = new IntParam(this, "intParam", "doc")
  final val floatParam: FloatParam = new FloatParam(this, "floatParam", "doc")
  final val doubleParam: DoubleParam = new DoubleParam(this, "doubleParam", "doc")
  final val longParam: LongParam = new LongParam(this, "longParam", "doc")
  final val stringParam: Param[String] = new Param[String](this, "stringParam", "doc")
  final val intArrayParam: IntArrayParam = new IntArrayParam(this, "intArrayParam", "doc")
  final val doubleArrayParam: DoubleArrayParam =
    new DoubleArrayParam(this, "doubleArrayParam", "doc")
  final val stringArrayParam: StringArrayParam =
    new StringArrayParam(this, "stringArrayParam", "doc")

  setDefault(intParamWithDefault -> 0)
  set(intParam -> 1)
  set(floatParam -> 2.0f)
  set(doubleParam -> 3.0)
  set(longParam -> 4L)
  set(stringParam -> "5")
  set(intArrayParam -> Array(6, 7))
  set(doubleArrayParam -> Array(8.0, 9.0))
  set(stringArrayParam -> Array("10", "11"))

  override def copy(extra: ParamMap): Params = defaultCopy(extra)

  override def write: Writer = new DefaultParamsWriter(this)
}

object MyParams extends Readable[MyParams] {

  override def read: Reader[MyParams] = new DefaultParamsReader[MyParams]

  override def load(path: String): MyParams = read.load(path)
}

class DefaultReadWriteSuite extends SparkFunSuite with MLlibTestSparkContext
  with DefaultReadWriteTest {

  test("default read/write") {
    val myParams = new MyParams("my_params")
    testDefaultReadWrite(myParams)
  }
}
~~~
TempDirectory.scala (new file):

~~~scala
package org.apache.spark.ml.util

import java.io.File

import org.scalatest.{BeforeAndAfterAll, Suite}

import org.apache.spark.util.Utils

/**
 * Trait that creates a temporary directory before all tests and deletes it after all.
 */
trait TempDirectory extends BeforeAndAfterAll { self: Suite =>

  private var _tempDir: File = _

  /** Returns the temporary directory as a [[File]] instance. */
  protected def tempDir: File = _tempDir

  override def beforeAll(): Unit = {
    super.beforeAll()
    _tempDir = Utils.createTempDir(this.getClass.getName)
  }

  override def afterAll(): Unit = {
    Utils.deleteRecursively(_tempDir)
    super.afterAll()
  }
}
~~~