Skip to content
Snippets Groups Projects
Commit 18350a57 authored by Joseph K. Bradley's avatar Joseph K. Bradley Committed by Xiangrui Meng
Browse files

[SPARK-11618][ML] Minor refactoring of basic ML import/export

Refactoring
* separated overwrite and param save logic in DefaultParamsWriter
* added sparkVersion to DefaultParamsWriter

CC: mengxr

Author: Joseph K. Bradley <joseph@databricks.com>

Closes #9587 from jkbradley/logreg-io.
parent f14e9511
No related branches found
No related tags found
No related merge requests found
...@@ -51,6 +51,9 @@ private[util] sealed trait BaseReadWrite { ...@@ -51,6 +51,9 @@ private[util] sealed trait BaseReadWrite {
protected final def sqlContext: SQLContext = optionSQLContext.getOrElse { protected final def sqlContext: SQLContext = optionSQLContext.getOrElse {
SQLContext.getOrCreate(SparkContext.getOrCreate()) SQLContext.getOrCreate(SparkContext.getOrCreate())
} }
/** Returns the [[SparkContext]] underlying [[sqlContext]] */
protected final def sc: SparkContext = sqlContext.sparkContext
} }
/** /**
...@@ -58,7 +61,7 @@ private[util] sealed trait BaseReadWrite { ...@@ -58,7 +61,7 @@ private[util] sealed trait BaseReadWrite {
*/ */
@Experimental @Experimental
@Since("1.6.0") @Since("1.6.0")
abstract class Writer extends BaseReadWrite { abstract class Writer extends BaseReadWrite with Logging {
protected var shouldOverwrite: Boolean = false protected var shouldOverwrite: Boolean = false
...@@ -67,7 +70,29 @@ abstract class Writer extends BaseReadWrite { ...@@ -67,7 +70,29 @@ abstract class Writer extends BaseReadWrite {
*/ */
@Since("1.6.0") @Since("1.6.0")
@throws[IOException]("If the input path already exists but overwrite is not enabled.") @throws[IOException]("If the input path already exists but overwrite is not enabled.")
def save(path: String): Unit def save(path: String): Unit = {
val hadoopConf = sc.hadoopConfiguration
val fs = FileSystem.get(hadoopConf)
val p = new Path(path)
if (fs.exists(p)) {
if (shouldOverwrite) {
logInfo(s"Path $path already exists. It will be overwritten.")
// TODO: Revert back to the original content if save is not successful.
fs.delete(p, true)
} else {
throw new IOException(
s"Path $path already exists. Please use write.overwrite().save(path) to overwrite it.")
}
}
saveImpl(path)
}
/**
* [[save()]] handles overwriting and then calls this method. Subclasses should override this
* method to implement the actual saving of the instance.
*/
@Since("1.6.0")
protected def saveImpl(path: String): Unit
/** /**
* Overwrites if the output path already exists. * Overwrites if the output path already exists.
...@@ -147,28 +172,9 @@ trait Readable[T] { ...@@ -147,28 +172,9 @@ trait Readable[T] {
* data (e.g., models with coefficients). * data (e.g., models with coefficients).
* @param instance object to save * @param instance object to save
*/ */
private[ml] class DefaultParamsWriter(instance: Params) extends Writer with Logging { private[ml] class DefaultParamsWriter(instance: Params) extends Writer {
/**
* Saves the ML component to the input path.
*/
override def save(path: String): Unit = {
val sc = sqlContext.sparkContext
val hadoopConf = sc.hadoopConfiguration
val fs = FileSystem.get(hadoopConf)
val p = new Path(path)
if (fs.exists(p)) {
if (shouldOverwrite) {
logInfo(s"Path $path already exists. It will be overwritten.")
// TODO: Revert back to the original content if save is not successful.
fs.delete(p, true)
} else {
throw new IOException(
s"Path $path already exists. Please use write.overwrite().save(path) to overwrite it.")
}
}
override protected def saveImpl(path: String): Unit = {
val uid = instance.uid val uid = instance.uid
val cls = instance.getClass.getName val cls = instance.getClass.getName
val params = instance.extractParamMap().toSeq.asInstanceOf[Seq[ParamPair[Any]]] val params = instance.extractParamMap().toSeq.asInstanceOf[Seq[ParamPair[Any]]]
...@@ -177,6 +183,7 @@ private[ml] class DefaultParamsWriter(instance: Params) extends Writer with Logg ...@@ -177,6 +183,7 @@ private[ml] class DefaultParamsWriter(instance: Params) extends Writer with Logg
}.toList }.toList
val metadata = ("class" -> cls) ~ val metadata = ("class" -> cls) ~
("timestamp" -> System.currentTimeMillis()) ~ ("timestamp" -> System.currentTimeMillis()) ~
("sparkVersion" -> sc.version) ~
("uid" -> uid) ~ ("uid" -> uid) ~
("paramMap" -> jsonParams) ("paramMap" -> jsonParams)
val metadataPath = new Path(path, "metadata").toString val metadataPath = new Path(path, "metadata").toString
...@@ -193,12 +200,8 @@ private[ml] class DefaultParamsWriter(instance: Params) extends Writer with Logg ...@@ -193,12 +200,8 @@ private[ml] class DefaultParamsWriter(instance: Params) extends Writer with Logg
*/ */
private[ml] class DefaultParamsReader[T] extends Reader[T] { private[ml] class DefaultParamsReader[T] extends Reader[T] {
/**
* Loads the ML component from the input path.
*/
override def load(path: String): T = { override def load(path: String): T = {
implicit val format = DefaultFormats implicit val format = DefaultFormats
val sc = sqlContext.sparkContext
val metadataPath = new Path(path, "metadata").toString val metadataPath = new Path(path, "metadata").toString
val metadataStr = sc.textFile(metadataPath, 1).first() val metadataStr = sc.textFile(metadataPath, 1).first()
val metadata = parse(metadataStr) val metadata = parse(metadataStr)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment