diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
index d13fb4193970bf7acfbb5922c5bc8f0383a7a164..abde04062c4b1232fd868aecddd17a172c5030d4 100644
--- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
@@ -17,17 +17,21 @@
 
 package org.apache.spark.deploy
 
-import java.io.{File, IOException}
+import java.io._
 import java.lang.reflect.{InvocationTargetException, Modifier, UndeclaredThrowableException}
 import java.net.URL
 import java.nio.file.Files
-import java.security.PrivilegedExceptionAction
+import java.security.{KeyStore, PrivilegedExceptionAction}
+import java.security.cert.X509Certificate
 import java.text.ParseException
+import javax.net.ssl._
 
 import scala.annotation.tailrec
 import scala.collection.mutable.{ArrayBuffer, HashMap, Map}
 import scala.util.Properties
 
+import com.google.common.io.ByteStreams
+import org.apache.commons.io.FileUtils
 import org.apache.commons.lang3.StringUtils
 import org.apache.hadoop.conf.{Configuration => HadoopConfiguration}
 import org.apache.hadoop.fs.{FileSystem, Path}
@@ -310,33 +314,33 @@ object SparkSubmit extends CommandLineUtils {
       RPackageUtils.checkAndBuildRPackage(args.jars, printStream, args.verbose)
     }
 
-    // In client mode, download remote files.
-    if (deployMode == CLIENT) {
-      val hadoopConf = new HadoopConfiguration()
-      args.primaryResource = Option(args.primaryResource).map(downloadFile(_, hadoopConf)).orNull
-      args.jars = Option(args.jars).map(downloadFileList(_, hadoopConf)).orNull
-      args.pyFiles = Option(args.pyFiles).map(downloadFileList(_, hadoopConf)).orNull
-      args.files = Option(args.files).map(downloadFileList(_, hadoopConf)).orNull
-    }
-
-    // Require all python files to be local, so we can add them to the PYTHONPATH
-    // In YARN cluster mode, python files are distributed as regular files, which can be non-local.
-    // In Mesos cluster mode, non-local python files are automatically downloaded by Mesos.
-    if (args.isPython && !isYarnCluster && !isMesosCluster) {
-      if (Utils.nonLocalPaths(args.primaryResource).nonEmpty) {
-        printErrorAndExit(s"Only local python files are supported: ${args.primaryResource}")
+    val hadoopConf = new HadoopConfiguration()
+    val targetDir = Files.createTempDirectory("tmp").toFile
+    // scalastyle:off runtimeaddshutdownhook
+    Runtime.getRuntime.addShutdownHook(new Thread() {
+      override def run(): Unit = {
+        FileUtils.deleteQuietly(targetDir)
       }
-      val nonLocalPyFiles = Utils.nonLocalPaths(args.pyFiles).mkString(",")
-      if (nonLocalPyFiles.nonEmpty) {
-        printErrorAndExit(s"Only local additional python files are supported: $nonLocalPyFiles")
-      }
-    }
+    })
+    // scalastyle:on runtimeaddshutdownhook
 
-    // Require all R files to be local
-    if (args.isR && !isYarnCluster && !isMesosCluster) {
-      if (Utils.nonLocalPaths(args.primaryResource).nonEmpty) {
-        printErrorAndExit(s"Only local R files are supported: ${args.primaryResource}")
-      }
+    // Resolve glob paths for different resources.
+    args.jars = Option(args.jars).map(resolveGlobPaths(_, hadoopConf)).orNull
+    args.files = Option(args.files).map(resolveGlobPaths(_, hadoopConf)).orNull
+    args.pyFiles = Option(args.pyFiles).map(resolveGlobPaths(_, hadoopConf)).orNull
+    args.archives = Option(args.archives).map(resolveGlobPaths(_, hadoopConf)).orNull
+
+    // In client mode, download remote files.
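+    // This allows http(s)/ftp and Hadoop-compatible URIs to be used for the
+    // primary resource, --jars and --py-files; downloads land in targetDir.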
+    if (deployMode == CLIENT) {
+      args.primaryResource = Option(args.primaryResource).map {
+        downloadFile(_, targetDir, args.sparkProperties, hadoopConf)
+      }.orNull
+      args.jars = Option(args.jars).map {
+        downloadFileList(_, targetDir, args.sparkProperties, hadoopConf)
+      }.orNull
+      args.pyFiles = Option(args.pyFiles).map {
+        downloadFileList(_, targetDir, args.sparkProperties, hadoopConf)
+      }.orNull
     }
 
     // The following modes are not supported or applicable
@@ -841,36 +845,132 @@ object SparkSubmit extends CommandLineUtils {
    * Download a list of remote files to temp local files. If the file is local, the original file
    * will be returned.
    * @param fileList A comma separated file list.
+   * @param targetDir A temporary directory to which downloaded files are saved.
+   * @param sparkProperties Spark properties.
    * @return A comma separated local files list.
    */
   private[deploy] def downloadFileList(
       fileList: String,
+      targetDir: File,
+      sparkProperties: Map[String, String],
       hadoopConf: HadoopConfiguration): String = {
     require(fileList != null, "fileList cannot be null.")
-    fileList.split(",").map(downloadFile(_, hadoopConf)).mkString(",")
+    fileList.split(",")
+      .map(downloadFile(_, targetDir, sparkProperties, hadoopConf))
+      .mkString(",")
   }
 
   /**
    * Download a file from the remote to a local temporary directory. If the input path points to
    * a local path, returns it with no operation.
+   * @param path A file path from which the file will be downloaded.
+   * @param targetDir A temporary directory to which the downloaded file is saved.
+   * @param sparkProperties Spark properties.
+   * @return The local path of the downloaded file.
    */
-  private[deploy] def downloadFile(path: String, hadoopConf: HadoopConfiguration): String = {
+  private[deploy] def downloadFile(
+      path: String,
+      targetDir: File,
+      sparkProperties: Map[String, String],
+      hadoopConf: HadoopConfiguration): String = {
     require(path != null, "path cannot be null.")
     val uri = Utils.resolveURI(path)
     uri.getScheme match {
-      case "file" | "local" =>
-        path
+      case "file" | "local" => path
+      case "http" | "https" | "ftp" =>
+        val uc = uri.toURL.openConnection()
+        uc match {
+          case https: HttpsURLConnection =>
+            val trustStore = sparkProperties.get("spark.ssl.fs.trustStore")
+              .orElse(sparkProperties.get("spark.ssl.trustStore"))
+            val trustStorePwd = sparkProperties.get("spark.ssl.fs.trustStorePassword")
+              .orElse(sparkProperties.get("spark.ssl.trustStorePassword"))
+              .map(_.toCharArray)
+              .orNull
+            val protocol = sparkProperties.get("spark.ssl.fs.protocol")
+              .orElse(sparkProperties.get("spark.ssl.protocol"))
+            if (protocol.isEmpty) {
+              printErrorAndExit("An SSL protocol (spark.ssl.fs.protocol or " +
+                "spark.ssl.protocol) is required when enabling SSL connections.")
+            }
+
+            val trustStoreManagers = trustStore.map { t =>
+              var input: InputStream = null
+              try {
+                input = new FileInputStream(new File(t))
+                val ks = KeyStore.getInstance(KeyStore.getDefaultType)
+                ks.load(input, trustStorePwd)
+                val tmf = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm)
+                tmf.init(ks)
+                tmf.getTrustManagers
+              } finally {
+                if (input != null) {
+                  input.close()
+                  input = null
+                }
+              }
+            }.getOrElse {
+              // No trust store configured: fall back to a trust manager that
+              // accepts any server certificate.
+              Array({
+                new X509TrustManager {
+                  override def getAcceptedIssuers: Array[X509Certificate] = null
+                  override def checkClientTrusted(
+                      x509Certificates: Array[X509Certificate], s: String) {}
+                  override def checkServerTrusted(
+                      x509Certificates: Array[X509Certificate], s: String) {}
+                }: TrustManager
+              })
+            }
+            val sslContext = SSLContext.getInstance(protocol.get)
+            sslContext.init(null, trustStoreManagers, null)
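+            // Install the SSL context configured above; the hostname verifier
+            // rejects connections whose certificate does not match the hostname.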
+            https.setSSLSocketFactory(sslContext.getSocketFactory)
+            https.setHostnameVerifier(new HostnameVerifier {
+              override def verify(s: String, sslSession: SSLSession): Boolean = false
+            })
+
+          case _ =>
+        }
+        uc.setConnectTimeout(60 * 1000)
+        uc.setReadTimeout(60 * 1000)
+        uc.connect()
+        val in = uc.getInputStream
+        val fileName = new Path(uri).getName
+        val tempFile = new File(targetDir, fileName)
+        val out = new FileOutputStream(tempFile)
+        // scalastyle:off println
+        printStream.println(s"Downloading ${uri.toString} to ${tempFile.getAbsolutePath}.")
+        // scalastyle:on println
+        try {
+          ByteStreams.copy(in, out)
+        } finally {
+          in.close()
+          out.close()
+        }
+        tempFile.toURI.toString
       case _ =>
         val fs = FileSystem.get(uri, hadoopConf)
-        val tmpFile = new File(Files.createTempDirectory("tmp").toFile, uri.getPath)
+        val tmpFile = new File(targetDir, new Path(uri).getName)
         // scalastyle:off println
         printStream.println(s"Downloading ${uri.toString} to ${tmpFile.getAbsolutePath}.")
         // scalastyle:on println
         fs.copyToLocalFile(new Path(uri), new Path(tmpFile.getAbsolutePath))
-        Utils.resolveURI(tmpFile.getAbsolutePath).toString
+        tmpFile.toURI.toString
     }
   }
+
+  private def resolveGlobPaths(paths: String, hadoopConf: HadoopConfiguration): String = {
+    require(paths != null, "paths cannot be null.")
+    paths.split(",").map(_.trim).filter(_.nonEmpty).flatMap { path =>
+      val uri = Utils.resolveURI(path)
+      uri.getScheme match {
+        case "local" | "http" | "https" | "ftp" => Array(path)
+        case _ =>
+          val fs = FileSystem.get(uri, hadoopConf)
+          Option(fs.globStatus(new Path(uri))).map { status =>
+            status.filter(_.isFile).map(_.getPath.toUri.toString)
+          }.getOrElse(Array(path))
+      }
+    }.mkString(",")
+  }
 }
 
 /** Provides utility functions to be used inside SparkSubmit. */
diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala
index 7800d3d624e3e1cef4bfc013cd2a40624cc2ed4e..fd1521193fdeea790cd223535f3ff66fd3f57ebb 100644
--- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala
@@ -520,7 +520,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S
         |                              (Default: client).
         |  --class CLASS_NAME          Your application's main class (for Java / Scala apps).
         |  --name NAME                 A name of your application.
-        |  --jars JARS                 Comma-separated list of local jars to include on the driver
+        |  --jars JARS                 Comma-separated list of jars to include on the driver
         |                              and executor classpaths.
         |  --packages                  Comma-separated list of maven coordinates of jars to include
         |                              on the driver and executor classpaths. Will search the local
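A minimal sketch of how the new glob resolution behaves (hypothetical directory contents; resolveGlobPaths is private to SparkSubmit, so this is illustrative rather than a runnable API call):

    // Suppose /opt/deps contains a.jar and b.jar on the local filesystem.
    // The local/http/https/ftp schemes pass through untouched; everything
    // else goes through FileSystem.globStatus:
    //
    //   resolveGlobPaths("/opt/deps/*.jar,local:/x.jar", hadoopConf)
    //   // yields something like:
    //   // "file:/opt/deps/a.jar,file:/opt/deps/b.jar,local:/x.jar"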
diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala
index b089357e7b868ca2deb45103d09de5ddf8a762f4..97357cdbb60838326849d745b3405e1919d7be6e 100644
--- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala
@@ -20,12 +20,14 @@ package org.apache.spark.deploy
 import java.io._
 import java.net.URI
 import java.nio.charset.StandardCharsets
+import java.nio.file.Files
 
+import scala.collection.mutable
 import scala.collection.mutable.ArrayBuffer
 import scala.io.Source
 
 import com.google.common.io.ByteStreams
-import org.apache.commons.io.{FilenameUtils, FileUtils}
+import org.apache.commons.io.FileUtils
 import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.fs.Path
 import org.scalatest.{BeforeAndAfterEach, Matchers}
@@ -42,7 +44,6 @@ import org.apache.spark.TestUtils.JavaSourceFromString
 import org.apache.spark.scheduler.EventLoggingListener
 import org.apache.spark.util.{CommandLineUtils, ResetSystemProperties, Utils}
 
-
 trait TestPrematureExit {
   suite: SparkFunSuite =>
 
@@ -726,6 +727,47 @@ class SparkSubmitSuite
     Utils.unionFileLists(None, Option("/tmp/a.jar")) should be (Set("/tmp/a.jar"))
     Utils.unionFileLists(Option("/tmp/a.jar"), None) should be (Set("/tmp/a.jar"))
   }
+
+  test("support glob path") {
+    val tmpJarDir = Utils.createTempDir()
+    val jar1 = TestUtils.createJarWithFiles(Map("test.resource" -> "1"), tmpJarDir)
+    val jar2 = TestUtils.createJarWithFiles(Map("test.resource" -> "USER"), tmpJarDir)
+
+    val tmpFileDir = Utils.createTempDir()
+    val file1 = File.createTempFile("tmpFile1", "", tmpFileDir)
+    val file2 = File.createTempFile("tmpFile2", "", tmpFileDir)
+
+    val tmpPyFileDir = Utils.createTempDir()
+    val pyFile1 = File.createTempFile("tmpPy1", ".py", tmpPyFileDir)
+    val pyFile2 = File.createTempFile("tmpPy2", ".egg", tmpPyFileDir)
+
+    val tmpArchiveDir = Utils.createTempDir()
+    val archive1 = File.createTempFile("archive1", ".zip", tmpArchiveDir)
+    val archive2 = File.createTempFile("archive2", ".zip", tmpArchiveDir)
+
+    val args = Seq(
+      "--class", UserClasspathFirstTest.getClass.getName.stripPrefix("$"),
+      "--name", "testApp",
+      "--master", "yarn",
+      "--deploy-mode", "client",
+      "--jars", s"${tmpJarDir.getAbsolutePath}/*.jar",
+      "--files", s"${tmpFileDir.getAbsolutePath}/tmpFile*",
+      "--py-files", s"${tmpPyFileDir.getAbsolutePath}/tmpPy*",
+      "--archives", s"${tmpArchiveDir.getAbsolutePath}/*.zip",
+      jar2.toString)
+
+    val appArgs = new SparkSubmitArguments(args)
+    val sysProps = SparkSubmit.prepareSubmitEnvironment(appArgs)._3
+    sysProps("spark.yarn.dist.jars").split(",").toSet should be (
+      Set(jar1.toURI.toString, jar2.toURI.toString))
+    sysProps("spark.yarn.dist.files").split(",").toSet should be (
+      Set(file1.toURI.toString, file2.toURI.toString))
+    sysProps("spark.submit.pyFiles").split(",").toSet should be (
+      Set(pyFile1.getAbsolutePath, pyFile2.getAbsolutePath))
+    sysProps("spark.yarn.dist.archives").split(",").toSet should be (
+      Set(archive1.toURI.toString, archive2.toURI.toString))
+  }
+
   // scalastyle:on println
 
   private def checkDownloadedFile(sourcePath: String, outputPath: String): Unit = {
@@ -738,7 +780,7 @@ class SparkSubmitSuite
     assert(outputUri.getScheme === "file")
 
     // The path and filename are preserved.
-    assert(outputUri.getPath.endsWith(sourceUri.getPath))
+    assert(outputUri.getPath.endsWith(new Path(sourceUri).getName))
     assert(FileUtils.readFileToString(new File(outputUri.getPath)) ===
       FileUtils.readFileToString(new File(sourceUri.getPath)))
   }
@@ -752,25 +794,29 @@ class SparkSubmitSuite
 
   test("downloadFile - invalid url") {
     intercept[IOException] {
-      SparkSubmit.downloadFile("abc:/my/file", new Configuration())
+      SparkSubmit.downloadFile(
+        "abc:/my/file", Utils.createTempDir(), mutable.Map.empty, new Configuration())
     }
   }
 
   test("downloadFile - file doesn't exist") {
     val hadoopConf = new Configuration()
+    val tmpDir = Utils.createTempDir()
     // Set s3a implementation to local file system for testing.
     hadoopConf.set("fs.s3a.impl", "org.apache.spark.deploy.TestFileSystem")
     // Disable file system impl cache to make sure the test file system is picked up.
     hadoopConf.set("fs.s3a.impl.disable.cache", "true")
     intercept[FileNotFoundException] {
-      SparkSubmit.downloadFile("s3a:/no/such/file", hadoopConf)
+      SparkSubmit.downloadFile("s3a:/no/such/file", tmpDir, mutable.Map.empty, hadoopConf)
     }
   }
 
   test("downloadFile does not download local file") {
     // empty path is considered as local file.
-    assert(SparkSubmit.downloadFile("", new Configuration()) === "")
-    assert(SparkSubmit.downloadFile("/local/file", new Configuration()) === "/local/file")
+    val tmpDir = Files.createTempDirectory("tmp").toFile
+    assert(SparkSubmit.downloadFile("", tmpDir, mutable.Map.empty, new Configuration()) === "")
+    assert(SparkSubmit.downloadFile("/local/file", tmpDir, mutable.Map.empty,
+      new Configuration()) === "/local/file")
   }
 
   test("download one file to local") {
@@ -779,12 +825,14 @@ class SparkSubmitSuite
     val content = "hello, world"
     FileUtils.write(jarFile, content)
     val hadoopConf = new Configuration()
+    val tmpDir = Files.createTempDirectory("tmp").toFile
     // Set s3a implementation to local file system for testing.
     hadoopConf.set("fs.s3a.impl", "org.apache.spark.deploy.TestFileSystem")
     // Disable file system impl cache to make sure the test file system is picked up.
     hadoopConf.set("fs.s3a.impl.disable.cache", "true")
     val sourcePath = s"s3a://${jarFile.getAbsolutePath}"
-    val outputPath = SparkSubmit.downloadFile(sourcePath, hadoopConf)
+    val outputPath =
+      SparkSubmit.downloadFile(sourcePath, tmpDir, mutable.Map.empty, hadoopConf)
     checkDownloadedFile(sourcePath, outputPath)
     deleteTempOutputFile(outputPath)
   }
@@ -795,12 +843,14 @@ class SparkSubmitSuite
     val content = "hello, world"
     FileUtils.write(jarFile, content)
     val hadoopConf = new Configuration()
+    val tmpDir = Files.createTempDirectory("tmp").toFile
     // Set s3a implementation to local file system for testing.
     hadoopConf.set("fs.s3a.impl", "org.apache.spark.deploy.TestFileSystem")
     // Disable file system impl cache to make sure the test file system is picked up.
hadoopConf.set("fs.s3a.impl.disable.cache", "true") val sourcePaths = Seq("/local/file", s"s3a://${jarFile.getAbsolutePath}") - val outputPaths = SparkSubmit.downloadFileList(sourcePaths.mkString(","), hadoopConf).split(",") + val outputPaths = SparkSubmit.downloadFileList( + sourcePaths.mkString(","), tmpDir, mutable.Map.empty, hadoopConf).split(",") assert(outputPaths.length === sourcePaths.length) sourcePaths.zip(outputPaths).foreach { case (sourcePath, outputPath) => diff --git a/docs/configuration.md b/docs/configuration.md index c785a664c67b1e48566a07006a55f3675bf01284..7dc23e441a7bad8962df904ae85afa452f8f9529 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -422,21 +422,21 @@ Apart from these, the following properties are also available, and may be useful <td><code>spark.files</code></td> <td></td> <td> - Comma-separated list of files to be placed in the working directory of each executor. + Comma-separated list of files to be placed in the working directory of each executor. Globs are allowed. </td> </tr> <tr> <td><code>spark.submit.pyFiles</code></td> <td></td> <td> - Comma-separated list of .zip, .egg, or .py files to place on the PYTHONPATH for Python apps. + Comma-separated list of .zip, .egg, or .py files to place on the PYTHONPATH for Python apps. Globs are allowed. </td> </tr> <tr> <td><code>spark.jars</code></td> <td></td> <td> - Comma-separated list of local jars to include on the driver and executor classpaths. + Comma-separated list of jars to include on the driver and executor classpaths. Globs are allowed. </td> </tr> <tr>