diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
index 77005aa9040b5e6a2bb429214b9d0088abff4243..c60a2a1706d5a80686c91ed872f55f10cdd6fb4a 100644
--- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
@@ -20,6 +20,7 @@ package org.apache.spark.deploy
 import java.io.{File, IOException}
 import java.lang.reflect.{InvocationTargetException, Modifier, UndeclaredThrowableException}
 import java.net.URL
+import java.nio.file.Files
 import java.security.PrivilegedExceptionAction
 import java.text.ParseException
 
@@ -28,7 +29,8 @@ import scala.collection.mutable.{ArrayBuffer, HashMap, Map}
 import scala.util.Properties
 
 import org.apache.commons.lang3.StringUtils
-import org.apache.hadoop.fs.Path
+import org.apache.hadoop.conf.{Configuration => HadoopConfiguration}
+import org.apache.hadoop.fs.{FileSystem, Path}
 import org.apache.hadoop.security.UserGroupInformation
 import org.apache.ivy.Ivy
 import org.apache.ivy.core.LogOptions
@@ -308,6 +310,15 @@ object SparkSubmit extends CommandLineUtils {
       RPackageUtils.checkAndBuildRPackage(args.jars, printStream, args.verbose)
     }
 
+    // In client mode, download remote files.
+    if (deployMode == CLIENT) {
+      val hadoopConf = new HadoopConfiguration()
+      args.primaryResource = Option(args.primaryResource).map(downloadFile(_, hadoopConf)).orNull
+      args.jars = Option(args.jars).map(downloadFileList(_, hadoopConf)).orNull
+      args.pyFiles = Option(args.pyFiles).map(downloadFileList(_, hadoopConf)).orNull
+      args.files = Option(args.files).map(downloadFileList(_, hadoopConf)).orNull
+    }
+
     // Require all python files to be local, so we can add them to the PYTHONPATH
     // In YARN cluster mode, python files are distributed as regular files, which can be non-local.
     // In Mesos cluster mode, non-local python files are automatically downloaded by Mesos.
@@ -825,6 +836,41 @@ object SparkSubmit extends CommandLineUtils {
       .mkString(",")
     if (merged == "") null else merged
   }
+
+  /**
+   * Download a list of remote files to temp local files. If the file is local, the original file
+   * will be returned.
+   * @param fileList A comma separated file list.
+   * @return A comma separated local files list.
+   */
+  private[deploy] def downloadFileList(
+      fileList: String,
+      hadoopConf: HadoopConfiguration): String = {
+    require(fileList != null, "fileList cannot be null.")
+    fileList.split(",").map(downloadFile(_, hadoopConf)).mkString(",")
+  }
+
+  /**
+   * Download a file from the remote to a local temporary directory. If the input path points to
+   * a local path, returns it with no operation.
+   */
+  private[deploy] def downloadFile(path: String, hadoopConf: HadoopConfiguration): String = {
+    require(path != null, "path cannot be null.")
+    val uri = Utils.resolveURI(path)
+    uri.getScheme match {
+      case "file" | "local" =>
+        path
+
+      case _ =>
+        val fs = FileSystem.get(uri, hadoopConf)
+        val tmpFile = new File(Files.createTempDirectory("tmp").toFile, uri.getPath)
+        // scalastyle:off println
+        printStream.println(s"Downloading ${uri.toString} to ${tmpFile.getAbsolutePath}.")
+        // scalastyle:on println
+        fs.copyToLocalFile(new Path(uri), new Path(tmpFile.getAbsolutePath))
+        Utils.resolveURI(tmpFile.getAbsolutePath).toString
+    }
+  }
 }
 
 /** Provides utility functions to be used inside SparkSubmit. */
diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala
index a43839a8815f9986706d95398939ac412f46cb23..6e9721c45931a5a3893014a7cf282d86e8de9d94 100644
--- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala
@@ -18,12 +18,15 @@
 package org.apache.spark.deploy
 
 import java.io._
+import java.net.URI
 import java.nio.charset.StandardCharsets
 
 import scala.collection.mutable.ArrayBuffer
 import scala.io.Source
 
 import com.google.common.io.ByteStreams
+import org.apache.commons.io.{FilenameUtils, FileUtils}
+import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.fs.Path
 import org.scalatest.{BeforeAndAfterEach, Matchers}
 import org.scalatest.concurrent.Timeouts
@@ -535,7 +538,7 @@ class SparkSubmitSuite
 
   test("resolves command line argument paths correctly") {
     val jars = "/jar1,/jar2" // --jars
-    val files = "hdfs:/file1,file2" // --files
+    val files = "local:/file1,file2" // --files
     val archives = "file:/archive1,archive2" // --archives
     val pyFiles = "py-file1,py-file2" // --py-files
 
@@ -587,7 +590,7 @@ class SparkSubmitSuite
 
   test("resolves config paths correctly") {
     val jars = "/jar1,/jar2" // spark.jars
-    val files = "hdfs:/file1,file2" // spark.files / spark.yarn.dist.files
+    val files = "local:/file1,file2" // spark.files / spark.yarn.dist.files
    val archives = "file:/archive1,archive2" // spark.yarn.dist.archives
     val pyFiles = "py-file1,py-file2" // spark.submit.pyFiles
 
@@ -705,6 +708,87 @@ class SparkSubmitSuite
   }
   // scalastyle:on println
 
+  private def checkDownloadedFile(sourcePath: String, outputPath: String): Unit = {
+    if (sourcePath == outputPath) {
+      return
+    }
+
+    val sourceUri = new URI(sourcePath)
+    val outputUri = new URI(outputPath)
+    assert(outputUri.getScheme === "file")
+
+    // The path and filename are preserved.
+    assert(outputUri.getPath.endsWith(sourceUri.getPath))
+    assert(FileUtils.readFileToString(new File(outputUri.getPath)) ===
+      FileUtils.readFileToString(new File(sourceUri.getPath)))
+  }
+
+  private def deleteTempOutputFile(outputPath: String): Unit = {
+    val outputFile = new File(new URI(outputPath).getPath)
+    if (outputFile.exists) {
+      outputFile.delete()
+    }
+  }
+
+  test("downloadFile - invalid url") {
+    intercept[IOException] {
+      SparkSubmit.downloadFile("abc:/my/file", new Configuration())
+    }
+  }
+
+  test("downloadFile - file doesn't exist") {
+    val hadoopConf = new Configuration()
+    // Set s3a implementation to local file system for testing.
+    hadoopConf.set("fs.s3a.impl", "org.apache.spark.deploy.TestFileSystem")
+    // Disable file system impl cache to make sure the test file system is picked up.
+    hadoopConf.set("fs.s3a.impl.disable.cache", "true")
+    intercept[FileNotFoundException] {
+      SparkSubmit.downloadFile("s3a:/no/such/file", hadoopConf)
+    }
+  }
+
+  test("downloadFile does not download local file") {
+    // empty path is considered as local file.
+    assert(SparkSubmit.downloadFile("", new Configuration()) === "")
+    assert(SparkSubmit.downloadFile("/local/file", new Configuration()) === "/local/file")
+  }
+
+  test("download one file to local") {
+    val jarFile = File.createTempFile("test", ".jar")
+    jarFile.deleteOnExit()
+    val content = "hello, world"
+    FileUtils.write(jarFile, content)
+    val hadoopConf = new Configuration()
+    // Set s3a implementation to local file system for testing.
+ hadoopConf.set("fs.s3a.impl", "org.apache.spark.deploy.TestFileSystem") + // Disable file system impl cache to make sure the test file system is picked up. + hadoopConf.set("fs.s3a.impl.disable.cache", "true") + val sourcePath = s"s3a://${jarFile.getAbsolutePath}" + val outputPath = SparkSubmit.downloadFile(sourcePath, hadoopConf) + checkDownloadedFile(sourcePath, outputPath) + deleteTempOutputFile(outputPath) + } + + test("download list of files to local") { + val jarFile = File.createTempFile("test", ".jar") + jarFile.deleteOnExit() + val content = "hello, world" + FileUtils.write(jarFile, content) + val hadoopConf = new Configuration() + // Set s3a implementation to local file system for testing. + hadoopConf.set("fs.s3a.impl", "org.apache.spark.deploy.TestFileSystem") + // Disable file system impl cache to make sure the test file system is picked up. + hadoopConf.set("fs.s3a.impl.disable.cache", "true") + val sourcePaths = Seq("/local/file", s"s3a://${jarFile.getAbsolutePath}") + val outputPaths = SparkSubmit.downloadFileList(sourcePaths.mkString(","), hadoopConf).split(",") + + assert(outputPaths.length === sourcePaths.length) + sourcePaths.zip(outputPaths).foreach { case (sourcePath, outputPath) => + checkDownloadedFile(sourcePath, outputPath) + deleteTempOutputFile(outputPath) + } + } + // NOTE: This is an expensive operation in terms of time (10 seconds+). Use sparingly. private def runSparkSubmit(args: Seq[String]): Unit = { val sparkHome = sys.props.getOrElse("spark.test.home", fail("spark.test.home is not set!")) @@ -807,3 +891,10 @@ object UserClasspathFirstTest { } } } + +class TestFileSystem extends org.apache.hadoop.fs.LocalFileSystem { + override def copyToLocalFile(src: Path, dst: Path): Unit = { + // Ignore the scheme for testing. + super.copyToLocalFile(new Path(src.toUri.getPath), dst) + } +}