Commit c4b1108c authored by Sandy Ryza, committed by Josh Rosen

SPARK-4687. Add a recursive option to the addFile API

This adds a recursive option to the addFile API to satisfy Hive's needs. It only allows specifying HDFS directories, which will be copied down to every executor.
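
For illustration, here is a rough usage sketch of the new overload. The HDFS path and file names below are made up for the example; only the addFile(path, recursive) signature and the SparkFiles.get lookup pattern come from this patch.

    import org.apache.spark.{SparkConf, SparkContext, SparkFiles}

    val sc = new SparkContext(new SparkConf().setAppName("addFile-recursive-example").setMaster("local"))

    // Existing behavior: ship a single file to every executor.
    sc.addFile("hdfs:///user/hive/resources/lookup.txt")

    // New behavior: ship a whole directory by passing recursive = true.
    // Directories are only supported for Hadoop-supported filesystems such as HDFS.
    sc.addFile("hdfs:///user/hive/resources", recursive = true)

    // On executors, the directory keeps its relative layout, so files can be looked up
    // as "<dirName>/<fileName>" (this mirrors the new SparkContextSuite test).
    sc.parallelize(1 to 1).foreach { _ =>
      println(SparkFiles.get("resources" + java.io.File.separator + "lookup.txt"))
    }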

There are a couple of outstanding questions.
* Should we allow specifying local dirs as well? The best way to do this would probably be to archive them. The drawback is that it would require a fair bit of code that I don't know of any current use cases for.
* The addFiles implementation has a caching component that I don't entirely understand. What events are we caching between? AFAICT it's users calling addFile on the same file in the same app at different times? Do we want/need to add something similar for addDirectory?
* The addFiles implementation will check to see if an added file already exists and has the same contents. I imagine we want the same behavior here, so I'm planning to add this unless people think otherwise.

I plan to add some tests if people are OK with the approach.

Author: Sandy Ryza <sandy@cloudera.com>

Closes #3670 from sryza/sandy-spark-4687 and squashes the following commits:

f9fc77f [Sandy Ryza] Josh's comments
70cd24d [Sandy Ryza] Add another test
13da824 [Sandy Ryza] Revert executor changes
38bf94d [Sandy Ryza] Marcelo's comments
ca83849 [Sandy Ryza] Add addFile test
1941be3 [Sandy Ryza] Fix test and avoid HTTP server in local mode
31f15a9 [Sandy Ryza] Use cache recursively and fix some compile errors
0239c3d [Sandy Ryza] Change addDirectory to addFile with recursive
46fe70a [Sandy Ryza] SPARK-4687. Add a addDirectory API
parent 6580929f
@@ -25,29 +25,37 @@ import java.net.URI
import java.util.{Arrays, Properties, UUID}
import java.util.concurrent.atomic.AtomicInteger
import java.util.UUID.randomUUID

import scala.collection.{Map, Set}
import scala.collection.JavaConversions._
import scala.collection.generic.Growable
import scala.collection.mutable.HashMap
import scala.reflect.{ClassTag, classTag}

+import akka.actor.Props
+
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
-import org.apache.hadoop.io.{ArrayWritable, BooleanWritable, BytesWritable, DoubleWritable, FloatWritable, IntWritable, LongWritable, NullWritable, Text, Writable}
-import org.apache.hadoop.mapred.{FileInputFormat, InputFormat, JobConf, SequenceFileInputFormat, TextInputFormat}
+import org.apache.hadoop.io.{ArrayWritable, BooleanWritable, BytesWritable, DoubleWritable,
+  FloatWritable, IntWritable, LongWritable, NullWritable, Text, Writable}
+import org.apache.hadoop.mapred.{FileInputFormat, InputFormat, JobConf, SequenceFileInputFormat,
+  TextInputFormat}
import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat, Job => NewHadoopJob}
import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat => NewFileInputFormat}
import org.apache.mesos.MesosNativeLibrary
-import akka.actor.Props

import org.apache.spark.annotation.{DeveloperApi, Experimental}
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.deploy.{LocalSparkCluster, SparkHadoopUtil}
import org.apache.spark.executor.TriggerThreadDump
-import org.apache.spark.input.{StreamInputFormat, PortableDataStream, WholeTextFileInputFormat, FixedLengthBinaryInputFormat}
+import org.apache.spark.input.{StreamInputFormat, PortableDataStream, WholeTextFileInputFormat,
+  FixedLengthBinaryInputFormat}
import org.apache.spark.partial.{ApproximateEvaluator, PartialResult}
import org.apache.spark.rdd._
import org.apache.spark.scheduler._
-import org.apache.spark.scheduler.cluster.{CoarseGrainedSchedulerBackend, SparkDeploySchedulerBackend, SimrSchedulerBackend}
+import org.apache.spark.scheduler.cluster.{CoarseGrainedSchedulerBackend,
+  SparkDeploySchedulerBackend, SimrSchedulerBackend}
import org.apache.spark.scheduler.cluster.mesos.{CoarseMesosSchedulerBackend, MesosSchedulerBackend}
import org.apache.spark.scheduler.local.LocalBackend
import org.apache.spark.storage._
@@ -1016,12 +1024,48 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
   * filesystems), or an HTTP, HTTPS or FTP URI. To access the file in Spark jobs,
   * use `SparkFiles.get(fileName)` to find its download location.
   */
-  def addFile(path: String) {
+  def addFile(path: String): Unit = {
+    addFile(path, false)
+  }
+
+  /**
+   * Add a file to be downloaded with this Spark job on every node.
+   * The `path` passed can be either a local file, a file in HDFS (or other Hadoop-supported
+   * filesystems), or an HTTP, HTTPS or FTP URI. To access the file in Spark jobs,
+   * use `SparkFiles.get(fileName)` to find its download location.
+   *
+   * A directory can be given if the recursive option is set to true. Currently directories are only
+   * supported for Hadoop-supported filesystems.
+   */
+  def addFile(path: String, recursive: Boolean): Unit = {
    val uri = new URI(path)
-    val key = uri.getScheme match {
-      case null | "file" => env.httpFileServer.addFile(new File(uri.getPath))
-      case "local" => "file:" + uri.getPath
-      case _ => path
+    val schemeCorrectedPath = uri.getScheme match {
+      case null | "local" => "file:" + uri.getPath
+      case _ => path
+    }
+
+    val hadoopPath = new Path(schemeCorrectedPath)
+    val scheme = new URI(schemeCorrectedPath).getScheme
+    if (!Array("http", "https", "ftp").contains(scheme)) {
+      val fs = hadoopPath.getFileSystem(hadoopConfiguration)
+      if (!fs.exists(hadoopPath)) {
+        throw new FileNotFoundException(s"Added file $hadoopPath does not exist.")
+      }
+      val isDir = fs.isDirectory(hadoopPath)
+      if (!isLocal && scheme == "file" && isDir) {
+        throw new SparkException(s"addFile does not support local directories when not running " +
+          "local mode.")
+      }
+      if (!recursive && isDir) {
+        throw new SparkException(s"Added file $hadoopPath is a directory and recursive is not " +
+          "turned on.")
+      }
+    }
+
+    val key = if (!isLocal && scheme == "file") {
+      env.httpFileServer.addFile(new File(uri.getPath))
+    } else {
+      schemeCorrectedPath
    }
+
    val timestamp = System.currentTimeMillis
    addedFiles(key) = timestamp
@@ -1633,8 +1677,8 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
    val schedulingMode = getSchedulingMode.toString
    val addedJarPaths = addedJars.keys.toSeq
    val addedFilePaths = addedFiles.keys.toSeq
-    val environmentDetails =
-      SparkEnv.environmentDetails(conf, schedulingMode, addedJarPaths, addedFilePaths)
+    val environmentDetails = SparkEnv.environmentDetails(conf, schedulingMode, addedJarPaths,
+      addedFilePaths)
    val environmentUpdate = SparkListenerEnvironmentUpdate(environmentDetails)
    listenerBus.post(environmentUpdate)
  }
@@ -386,8 +386,10 @@ private[spark] object Utils extends Logging {
  }

  /**
-   * Download a file to target directory. Supports fetching the file in a variety of ways,
-   * including HTTP, HDFS and files on a standard filesystem, based on the URL parameter.
+   * Download a file or directory to target directory. Supports fetching the file in a variety of
+   * ways, including HTTP, Hadoop-compatible filesystems, and files on a standard filesystem, based
+   * on the URL parameter. Fetching directories is only supported from Hadoop-compatible
+   * filesystems.
   *
   * If `useCache` is true, first attempts to fetch the file to a local cache that's shared
   * across executors running the same application. `useCache` is used mainly for
@@ -456,7 +458,6 @@ private[spark] object Utils extends Logging {
   *
   * @param url URL that `sourceFile` originated from, for logging purposes.
   * @param in InputStream to download.
-   * @param tempFile File path to download `in` to.
   * @param destFile File path to move `tempFile` to.
   * @param fileOverwrite Whether to delete/overwrite an existing `destFile` that does not match
   *                      `sourceFile`
@@ -464,9 +465,11 @@ private[spark] object Utils extends Logging {
  private def downloadFile(
      url: String,
      in: InputStream,
-      tempFile: File,
      destFile: File,
      fileOverwrite: Boolean): Unit = {
+    val tempFile = File.createTempFile("fetchFileTemp", null,
+      new File(destFile.getParentFile.getAbsolutePath))
+    logInfo(s"Fetching $url to $tempFile")
    try {
      val out = new FileOutputStream(tempFile)
@@ -505,7 +508,7 @@ private[spark] object Utils extends Logging {
      removeSourceFile: Boolean = false): Unit = {

    if (destFile.exists) {
-      if (!Files.equal(sourceFile, destFile)) {
+      if (!filesEqualRecursive(sourceFile, destFile)) {
        if (fileOverwrite) {
          logInfo(
            s"File $destFile exists and does not match contents of $url, replacing it with $url"
@@ -540,13 +543,44 @@ private[spark] object Utils extends Logging {
      Files.move(sourceFile, destFile)
    } else {
      logInfo(s"Copying ${sourceFile.getAbsolutePath} to ${destFile.getAbsolutePath}")
-      Files.copy(sourceFile, destFile)
+      copyRecursive(sourceFile, destFile)
+    }
+  }
+
+  private def filesEqualRecursive(file1: File, file2: File): Boolean = {
+    if (file1.isDirectory && file2.isDirectory) {
+      val subfiles1 = file1.listFiles()
+      val subfiles2 = file2.listFiles()
+      if (subfiles1.size != subfiles2.size) {
+        return false
+      }
+      subfiles1.sortBy(_.getName).zip(subfiles2.sortBy(_.getName)).forall {
+        case (f1, f2) => filesEqualRecursive(f1, f2)
+      }
+    } else if (file1.isFile && file2.isFile) {
+      Files.equal(file1, file2)
+    } else {
+      false
+    }
+  }
+
+  private def copyRecursive(source: File, dest: File): Unit = {
+    if (source.isDirectory) {
+      if (!dest.mkdir()) {
+        throw new IOException(s"Failed to create directory ${dest.getPath}")
+      }
+      val subfiles = source.listFiles()
+      subfiles.foreach(f => copyRecursive(f, new File(dest, f.getName)))
+    } else {
+      Files.copy(source, dest)
+    }
    }
  }

  /**
-   * Download a file to target directory. Supports fetching the file in a variety of ways,
-   * including HTTP, HDFS and files on a standard filesystem, based on the URL parameter.
+   * Download a file or directory to target directory. Supports fetching the file in a variety of
+   * ways, including HTTP, Hadoop-compatible filesystems, and files on a standard filesystem, based
+   * on the URL parameter. Fetching directories is only supported from Hadoop-compatible
+   * filesystems.
   *
   * Throws SparkException if the target file already exists and has different contents than
   * the requested file.
@@ -558,14 +592,11 @@ private[spark] object Utils extends Logging {
      conf: SparkConf,
      securityMgr: SecurityManager,
      hadoopConf: Configuration) {
-    val tempFile = File.createTempFile("fetchFileTemp", null, new File(targetDir.getAbsolutePath))
    val targetFile = new File(targetDir, filename)
    val uri = new URI(url)
    val fileOverwrite = conf.getBoolean("spark.files.overwrite", defaultValue = false)
    Option(uri.getScheme).getOrElse("file") match {
      case "http" | "https" | "ftp" =>
-        logInfo("Fetching " + url + " to " + tempFile)
        var uc: URLConnection = null
        if (securityMgr.isAuthenticationEnabled()) {
          logDebug("fetchFile with security enabled")
@@ -583,17 +614,44 @@ private[spark] object Utils extends Logging {
        uc.setReadTimeout(timeout)
        uc.connect()
        val in = uc.getInputStream()
-        downloadFile(url, in, tempFile, targetFile, fileOverwrite)
+        downloadFile(url, in, targetFile, fileOverwrite)
      case "file" =>
        // In the case of a local file, copy the local file to the target directory.
        // Note the difference between uri vs url.
        val sourceFile = if (uri.isAbsolute) new File(uri) else new File(url)
        copyFile(url, sourceFile, targetFile, fileOverwrite)
      case _ =>
-        // Use the Hadoop filesystem library, which supports file://, hdfs://, s3://, and others
        val fs = getHadoopFileSystem(uri, hadoopConf)
-        val in = fs.open(new Path(uri))
-        downloadFile(url, in, tempFile, targetFile, fileOverwrite)
+        val path = new Path(uri)
+        fetchHcfsFile(path, new File(targetDir, path.getName), fs, conf, hadoopConf, fileOverwrite)
+    }
+  }
+
+  /**
+   * Fetch a file or directory from a Hadoop-compatible filesystem.
+   *
+   * Visible for testing
+   */
+  private[spark] def fetchHcfsFile(
+      path: Path,
+      targetDir: File,
+      fs: FileSystem,
+      conf: SparkConf,
+      hadoopConf: Configuration,
+      fileOverwrite: Boolean): Unit = {
+    if (!targetDir.mkdir()) {
+      throw new IOException(s"Failed to create directory ${targetDir.getPath}")
+    }
+    fs.listStatus(path).foreach { fileStatus =>
+      val innerPath = fileStatus.getPath
+      if (fileStatus.isDir) {
+        fetchHcfsFile(innerPath, new File(targetDir, innerPath.getName), fs, conf, hadoopConf,
+          fileOverwrite)
+      } else {
+        val in = fs.open(innerPath)
+        val targetFile = new File(targetDir, innerPath.getName)
+        downloadFile(innerPath.toString, in, targetFile, fileOverwrite)
+      }
    }
  }
@@ -17,10 +17,17 @@
package org.apache.spark

+import java.io.File
+
+import com.google.common.base.Charsets._
+import com.google.common.io.Files
+
import org.scalatest.FunSuite

import org.apache.hadoop.io.BytesWritable

+import org.apache.spark.util.Utils
+
class SparkContextSuite extends FunSuite with LocalSparkContext {

  test("Only one SparkContext may be active at a time") {
@@ -72,4 +79,74 @@ class SparkContextSuite extends FunSuite with LocalSparkContext {
    val byteArray2 = converter.convert(bytesWritable)
    assert(byteArray2.length === 0)
  }
+
+  test("addFile works") {
+    val file = File.createTempFile("someprefix", "somesuffix")
+    val absolutePath = file.getAbsolutePath
+    try {
+      Files.write("somewords", file, UTF_8)
+      val length = file.length()
+      sc = new SparkContext(new SparkConf().setAppName("test").setMaster("local"))
+      sc.addFile(file.getAbsolutePath)
+      sc.parallelize(Array(1), 1).map(x => {
+        val gotten = new File(SparkFiles.get(file.getName))
+        if (!gotten.exists()) {
+          throw new SparkException("file doesn't exist")
+        }
+        if (length != gotten.length()) {
+          throw new SparkException(
+            s"file has different length $length than added file ${gotten.length()}")
+        }
+        if (absolutePath == gotten.getAbsolutePath) {
+          throw new SparkException("file should have been copied")
+        }
+        x
+      }).count()
+    } finally {
+      sc.stop()
+    }
+  }
+
+  test("addFile recursive works") {
+    val pluto = Utils.createTempDir()
+    val neptune = Utils.createTempDir(pluto.getAbsolutePath)
+    val saturn = Utils.createTempDir(neptune.getAbsolutePath)
+    val alien1 = File.createTempFile("alien", "1", neptune)
+    val alien2 = File.createTempFile("alien", "2", saturn)
+
+    try {
+      sc = new SparkContext(new SparkConf().setAppName("test").setMaster("local"))
+      sc.addFile(neptune.getAbsolutePath, true)
+      sc.parallelize(Array(1), 1).map(x => {
+        val sep = File.separator
+        if (!new File(SparkFiles.get(neptune.getName + sep + alien1.getName)).exists()) {
+          throw new SparkException("can't access file under root added directory")
+        }
+        if (!new File(SparkFiles.get(neptune.getName + sep + saturn.getName + sep + alien2.getName))
+            .exists()) {
+          throw new SparkException("can't access file in nested directory")
+        }
+        if (new File(SparkFiles.get(pluto.getName + sep + neptune.getName + sep + alien1.getName))
+            .exists()) {
+          throw new SparkException("file exists that shouldn't")
+        }
+        x
+      }).count()
+    } finally {
+      sc.stop()
+    }
+  }
+
+  test("addFile recursive can't add directories by default") {
+    val dir = Utils.createTempDir()
+
+    try {
+      sc = new SparkContext(new SparkConf().setAppName("test").setMaster("local"))
+      intercept[SparkException] {
+        sc.addFile(dir.getAbsolutePath)
+      }
+    } finally {
+      sc.stop()
+    }
+  }
}
@@ -29,6 +29,9 @@ import com.google.common.base.Charsets.UTF_8
import com.google.common.io.Files
import org.scalatest.FunSuite

+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.Path
+
import org.apache.spark.SparkConf

class UtilsSuite extends FunSuite with ResetSystemProperties {
@@ -381,4 +384,32 @@ class UtilsSuite extends FunSuite with ResetSystemProperties {
    require(cnt === 2, "prepare should be called twice")
    require(time < 500, "preparation time should not count")
  }
+
+  test("fetch hcfs dir") {
+    val tempDir = Utils.createTempDir()
+    val innerTempDir = Utils.createTempDir(tempDir.getPath)
+    val tempFile = File.createTempFile("someprefix", "somesuffix", innerTempDir)
+    val targetDir = new File("target-dir")
+    Files.write("some text", tempFile, UTF_8)
+
+    try {
+      val path = new Path("file://" + tempDir.getAbsolutePath)
+      val conf = new Configuration()
+      val fs = Utils.getHadoopFileSystem(path.toString, conf)
+      Utils.fetchHcfsFile(path, targetDir, fs, new SparkConf(), conf, false)
+      assert(targetDir.exists())
+      assert(targetDir.isDirectory())
+      val newInnerDir = new File(targetDir, innerTempDir.getName)
+      println("inner temp dir: " + innerTempDir.getName)
+      targetDir.listFiles().map(_.getName).foreach(println)
+      assert(newInnerDir.exists())
+      assert(newInnerDir.isDirectory())
+      val newInnerFile = new File(newInnerDir, tempFile.getName)
+      assert(newInnerFile.exists())
+      assert(newInnerFile.isFile())
+    } finally {
+      Utils.deleteRecursively(tempDir)
+      Utils.deleteRecursively(targetDir)
+    }
+  }
}