From 5800144a54f5c0180ccf67392f32c3e8a51119b1 Mon Sep 17 00:00:00 2001
From: jerryshao <sshao@hortonworks.com>
Date: Thu, 6 Jul 2017 15:32:49 +0800
Subject: [PATCH] [SPARK-21012][SUBMIT] Add glob support for resources adding
 to Spark

Current "--jars (spark.jars)", "--files (spark.files)", "--py-files (spark.submit.pyFiles)" and "--archives (spark.yarn.dist.archives)" only support non-glob path. This is OK for most of the cases, but when user requires to add more jars, files into Spark, it is too verbose to list one by one. So here propose to add glob path support for resources.

It also improves the code that downloads remote resources.
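
In client mode, http, https and ftp resources are now downloaded by SparkSubmit itself. For an HTTPS source signed by a custom CA, the truststore can be supplied through the SSL properties that downloadFile reads (the values below are placeholders, not defaults):

```
./bin/spark-submit \
  --master yarn \
  --deploy-mode client \
  --class com.example.MyApp \
  --conf spark.ssl.fs.trustStore=/path/to/truststore.jks \
  --conf spark.ssl.fs.trustStorePassword=changeit \
  --conf spark.ssl.fs.protocol=TLSv1.2 \
  --jars "https://repo.example.com/libs/extra.jar" \
  my-app.jar
```

If the `spark.ssl.fs.*` properties are not set, the code falls back to `spark.ssl.trustStore`, `spark.ssl.trustStorePassword` and `spark.ssl.protocol`; when no truststore is configured, a trust-all manager is used for the download connection.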

## How was this patch tested?

Unit tests added; also verified manually on a local cluster.

Author: jerryshao <sshao@hortonworks.com>

Closes #18235 from jerryshao/SPARK-21012.
---
 .../org/apache/spark/deploy/SparkSubmit.scala | 166 ++++++++++++++----
 .../spark/deploy/SparkSubmitArguments.scala   |   2 +-
 .../spark/deploy/SparkSubmitSuite.scala       |  68 ++++++-
 docs/configuration.md                         |   6 +-
 4 files changed, 196 insertions(+), 46 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
index d13fb41939..abde04062c 100644
--- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
@@ -17,17 +17,21 @@
 
 package org.apache.spark.deploy
 
-import java.io.{File, IOException}
+import java.io._
 import java.lang.reflect.{InvocationTargetException, Modifier, UndeclaredThrowableException}
 import java.net.URL
 import java.nio.file.Files
-import java.security.PrivilegedExceptionAction
+import java.security.{KeyStore, PrivilegedExceptionAction}
+import java.security.cert.X509Certificate
 import java.text.ParseException
+import javax.net.ssl._
 
 import scala.annotation.tailrec
 import scala.collection.mutable.{ArrayBuffer, HashMap, Map}
 import scala.util.Properties
 
+import com.google.common.io.ByteStreams
+import org.apache.commons.io.FileUtils
 import org.apache.commons.lang3.StringUtils
 import org.apache.hadoop.conf.{Configuration => HadoopConfiguration}
 import org.apache.hadoop.fs.{FileSystem, Path}
@@ -310,33 +314,33 @@ object SparkSubmit extends CommandLineUtils {
       RPackageUtils.checkAndBuildRPackage(args.jars, printStream, args.verbose)
     }
 
-    // In client mode, download remote files.
-    if (deployMode == CLIENT) {
-      val hadoopConf = new HadoopConfiguration()
-      args.primaryResource = Option(args.primaryResource).map(downloadFile(_, hadoopConf)).orNull
-      args.jars = Option(args.jars).map(downloadFileList(_, hadoopConf)).orNull
-      args.pyFiles = Option(args.pyFiles).map(downloadFileList(_, hadoopConf)).orNull
-      args.files = Option(args.files).map(downloadFileList(_, hadoopConf)).orNull
-    }
-
-    // Require all python files to be local, so we can add them to the PYTHONPATH
-    // In YARN cluster mode, python files are distributed as regular files, which can be non-local.
-    // In Mesos cluster mode, non-local python files are automatically downloaded by Mesos.
-    if (args.isPython && !isYarnCluster && !isMesosCluster) {
-      if (Utils.nonLocalPaths(args.primaryResource).nonEmpty) {
-        printErrorAndExit(s"Only local python files are supported: ${args.primaryResource}")
+    val hadoopConf = new HadoopConfiguration()
+    val targetDir = Files.createTempDirectory("tmp").toFile
+    // scalastyle:off runtimeaddshutdownhook
+    Runtime.getRuntime.addShutdownHook(new Thread() {
+      override def run(): Unit = {
+        FileUtils.deleteQuietly(targetDir)
       }
-      val nonLocalPyFiles = Utils.nonLocalPaths(args.pyFiles).mkString(",")
-      if (nonLocalPyFiles.nonEmpty) {
-        printErrorAndExit(s"Only local additional python files are supported: $nonLocalPyFiles")
-      }
-    }
+    })
+    // scalastyle:on runtimeaddshutdownhook
 
-    // Require all R files to be local
-    if (args.isR && !isYarnCluster && !isMesosCluster) {
-      if (Utils.nonLocalPaths(args.primaryResource).nonEmpty) {
-        printErrorAndExit(s"Only local R files are supported: ${args.primaryResource}")
-      }
+    // Resolve glob path for different resources.
+    args.jars = Option(args.jars).map(resolveGlobPaths(_, hadoopConf)).orNull
+    args.files = Option(args.files).map(resolveGlobPaths(_, hadoopConf)).orNull
+    args.pyFiles = Option(args.pyFiles).map(resolveGlobPaths(_, hadoopConf)).orNull
+    args.archives = Option(args.archives).map(resolveGlobPaths(_, hadoopConf)).orNull
+
+    // In client mode, download remote files.
+    if (deployMode == CLIENT) {
+      args.primaryResource = Option(args.primaryResource).map {
+        downloadFile(_, targetDir, args.sparkProperties, hadoopConf)
+      }.orNull
+      args.jars = Option(args.jars).map {
+        downloadFileList(_, targetDir, args.sparkProperties, hadoopConf)
+      }.orNull
+      args.pyFiles = Option(args.pyFiles).map {
+        downloadFileList(_, targetDir, args.sparkProperties, hadoopConf)
+      }.orNull
     }
 
     // The following modes are not supported or applicable
@@ -841,36 +845,132 @@ object SparkSubmit extends CommandLineUtils {
    * Download a list of remote files to temp local files. If the file is local, the original file
    * will be returned.
    * @param fileList A comma separated file list.
+   * @param targetDir A temporary directory into which the files are downloaded.
+   * @param sparkProperties Spark properties
    * @return A comma separated local files list.
    */
   private[deploy] def downloadFileList(
       fileList: String,
+      targetDir: File,
+      sparkProperties: Map[String, String],
       hadoopConf: HadoopConfiguration): String = {
     require(fileList != null, "fileList cannot be null.")
-    fileList.split(",").map(downloadFile(_, hadoopConf)).mkString(",")
+    fileList.split(",")
+      .map(downloadFile(_, targetDir, sparkProperties, hadoopConf))
+      .mkString(",")
   }
 
   /**
    * Download a file from the remote to a local temporary directory. If the input path points to
    * a local path, returns it with no operation.
+   * @param path A file path from which the file will be downloaded.
+   * @param targetDir A temporary directory into which the file is downloaded.
+   * @param sparkProperties Spark properties
+   * @return The local path of the downloaded file.
    */
-  private[deploy] def downloadFile(path: String, hadoopConf: HadoopConfiguration): String = {
+  private[deploy] def downloadFile(
+      path: String,
+      targetDir: File,
+      sparkProperties: Map[String, String],
+      hadoopConf: HadoopConfiguration): String = {
     require(path != null, "path cannot be null.")
     val uri = Utils.resolveURI(path)
     uri.getScheme match {
-      case "file" | "local" =>
-        path
+      case "file" | "local" => path
+      case "http" | "https" | "ftp" =>
+        val uc = uri.toURL.openConnection()
+        uc match {
+          case https: HttpsURLConnection =>
+            val trustStore = sparkProperties.get("spark.ssl.fs.trustStore")
+              .orElse(sparkProperties.get("spark.ssl.trustStore"))
+            val trustStorePwd = sparkProperties.get("spark.ssl.fs.trustStorePassword")
+              .orElse(sparkProperties.get("spark.ssl.trustStorePassword"))
+              .map(_.toCharArray)
+              .orNull
+            val protocol = sparkProperties.get("spark.ssl.fs.protocol")
+              .orElse(sparkProperties.get("spark.ssl.protocol"))
+            if (protocol.isEmpty) {
+              printErrorAndExit("spark ssl protocol is required when enabling SSL connection.")
+            }
+
+            val trustStoreManagers = trustStore.map { t =>
+              var input: InputStream = null
+              try {
+                input = new FileInputStream(new File(t))
+                val ks = KeyStore.getInstance(KeyStore.getDefaultType)
+                ks.load(input, trustStorePwd)
+                val tmf = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm)
+                tmf.init(ks)
+                tmf.getTrustManagers
+              } finally {
+                if (input != null) {
+                  input.close()
+                  input = null
+                }
+              }
+            }.getOrElse {
+              Array({
+                new X509TrustManager {
+                  override def getAcceptedIssuers: Array[X509Certificate] = null
+                  override def checkClientTrusted(
+                      x509Certificates: Array[X509Certificate], s: String) {}
+                  override def checkServerTrusted(
+                      x509Certificates: Array[X509Certificate], s: String) {}
+                }: TrustManager
+              })
+            }
+            val sslContext = SSLContext.getInstance(protocol.get)
+            sslContext.init(null, trustStoreManagers, null)
+            https.setSSLSocketFactory(sslContext.getSocketFactory)
+            https.setHostnameVerifier(new HostnameVerifier {
+              override def verify(s: String, sslSession: SSLSession): Boolean = false
+            })
+
+          case _ =>
+        }
 
+        uc.setConnectTimeout(60 * 1000)
+        uc.setReadTimeout(60 * 1000)
+        uc.connect()
+        val in = uc.getInputStream
+        val fileName = new Path(uri).getName
+        val tempFile = new File(targetDir, fileName)
+        val out = new FileOutputStream(tempFile)
+        // scalastyle:off println
+        printStream.println(s"Downloading ${uri.toString} to ${tempFile.getAbsolutePath}.")
+        // scalastyle:on println
+        try {
+          ByteStreams.copy(in, out)
+        } finally {
+          in.close()
+          out.close()
+        }
+        tempFile.toURI.toString
       case _ =>
         val fs = FileSystem.get(uri, hadoopConf)
-        val tmpFile = new File(Files.createTempDirectory("tmp").toFile, uri.getPath)
+        val tmpFile = new File(targetDir, new Path(uri).getName)
         // scalastyle:off println
         printStream.println(s"Downloading ${uri.toString} to ${tmpFile.getAbsolutePath}.")
         // scalastyle:on println
         fs.copyToLocalFile(new Path(uri), new Path(tmpFile.getAbsolutePath))
-        Utils.resolveURI(tmpFile.getAbsolutePath).toString
+        tmpFile.toURI.toString
     }
   }
+
+  private def resolveGlobPaths(paths: String, hadoopConf: HadoopConfiguration): String = {
+    require(paths != null, "paths cannot be null.")
+    paths.split(",").map(_.trim).filter(_.nonEmpty).flatMap { path =>
+      val uri = Utils.resolveURI(path)
+      uri.getScheme match {
+        case "local" | "http" | "https" | "ftp" => Array(path)
+        case _ =>
+          val fs = FileSystem.get(uri, hadoopConf)
+          Option(fs.globStatus(new Path(uri))).map { status =>
+            status.filter(_.isFile).map(_.getPath.toUri.toString)
+          }.getOrElse(Array(path))
+      }
+    }.mkString(",")
+  }
 }
 
 /** Provides utility functions to be used inside SparkSubmit. */
diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala
index 7800d3d624..fd1521193f 100644
--- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala
@@ -520,7 +520,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S
         |                              (Default: client).
         |  --class CLASS_NAME          Your application's main class (for Java / Scala apps).
         |  --name NAME                 A name of your application.
-        |  --jars JARS                 Comma-separated list of local jars to include on the driver
+        |  --jars JARS                 Comma-separated list of jars to include on the driver
         |                              and executor classpaths.
         |  --packages                  Comma-separated list of maven coordinates of jars to include
         |                              on the driver and executor classpaths. Will search the local
diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala
index b089357e7b..97357cdbb6 100644
--- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala
@@ -20,12 +20,14 @@ package org.apache.spark.deploy
 import java.io._
 import java.net.URI
 import java.nio.charset.StandardCharsets
+import java.nio.file.Files
 
+import scala.collection.mutable
 import scala.collection.mutable.ArrayBuffer
 import scala.io.Source
 
 import com.google.common.io.ByteStreams
-import org.apache.commons.io.{FilenameUtils, FileUtils}
+import org.apache.commons.io.FileUtils
 import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.fs.Path
 import org.scalatest.{BeforeAndAfterEach, Matchers}
@@ -42,7 +44,6 @@ import org.apache.spark.TestUtils.JavaSourceFromString
 import org.apache.spark.scheduler.EventLoggingListener
 import org.apache.spark.util.{CommandLineUtils, ResetSystemProperties, Utils}
 
-
 trait TestPrematureExit {
   suite: SparkFunSuite =>
 
@@ -726,6 +727,47 @@ class SparkSubmitSuite
     Utils.unionFileLists(None, Option("/tmp/a.jar")) should be (Set("/tmp/a.jar"))
     Utils.unionFileLists(Option("/tmp/a.jar"), None) should be (Set("/tmp/a.jar"))
   }
+
+  test("support glob path") {
+    val tmpJarDir = Utils.createTempDir()
+    val jar1 = TestUtils.createJarWithFiles(Map("test.resource" -> "1"), tmpJarDir)
+    val jar2 = TestUtils.createJarWithFiles(Map("test.resource" -> "USER"), tmpJarDir)
+
+    val tmpFileDir = Utils.createTempDir()
+    val file1 = File.createTempFile("tmpFile1", "", tmpFileDir)
+    val file2 = File.createTempFile("tmpFile2", "", tmpFileDir)
+
+    val tmpPyFileDir = Utils.createTempDir()
+    val pyFile1 = File.createTempFile("tmpPy1", ".py", tmpPyFileDir)
+    val pyFile2 = File.createTempFile("tmpPy2", ".egg", tmpPyFileDir)
+
+    val tmpArchiveDir = Utils.createTempDir()
+    val archive1 = File.createTempFile("archive1", ".zip", tmpArchiveDir)
+    val archive2 = File.createTempFile("archive2", ".zip", tmpArchiveDir)
+
+    val args = Seq(
+      "--class", UserClasspathFirstTest.getClass.getName.stripPrefix("$"),
+      "--name", "testApp",
+      "--master", "yarn",
+      "--deploy-mode", "client",
+      "--jars", s"${tmpJarDir.getAbsolutePath}/*.jar",
+      "--files", s"${tmpFileDir.getAbsolutePath}/tmpFile*",
+      "--py-files", s"${tmpPyFileDir.getAbsolutePath}/tmpPy*",
+      "--archives", s"${tmpArchiveDir.getAbsolutePath}/*.zip",
+      jar2.toString)
+
+    val appArgs = new SparkSubmitArguments(args)
+    val sysProps = SparkSubmit.prepareSubmitEnvironment(appArgs)._3
+    sysProps("spark.yarn.dist.jars").split(",").toSet should be
+      (Set(jar1.toURI.toString, jar2.toURI.toString))
+    sysProps("spark.yarn.dist.files").split(",").toSet should be
+      (Set(file1.toURI.toString, file2.toURI.toString))
+    sysProps("spark.submit.pyFiles").split(",").toSet should be
+      (Set(pyFile1.getAbsolutePath, pyFile2.getAbsolutePath))
+    sysProps("spark.yarn.dist.archives").split(",").toSet should be
+      (Set(archive1.toURI.toString, archive2.toURI.toString))
+  }
+
   // scalastyle:on println
 
   private def checkDownloadedFile(sourcePath: String, outputPath: String): Unit = {
@@ -738,7 +780,7 @@ class SparkSubmitSuite
     assert(outputUri.getScheme === "file")
 
     // The path and filename are preserved.
-    assert(outputUri.getPath.endsWith(sourceUri.getPath))
+    assert(outputUri.getPath.endsWith(new Path(sourceUri).getName))
     assert(FileUtils.readFileToString(new File(outputUri.getPath)) ===
       FileUtils.readFileToString(new File(sourceUri.getPath)))
   }
@@ -752,25 +794,29 @@ class SparkSubmitSuite
 
   test("downloadFile - invalid url") {
     intercept[IOException] {
-      SparkSubmit.downloadFile("abc:/my/file", new Configuration())
+      SparkSubmit.downloadFile(
+        "abc:/my/file", Utils.createTempDir(), mutable.Map.empty, new Configuration())
     }
   }
 
   test("downloadFile - file doesn't exist") {
     val hadoopConf = new Configuration()
+    val tmpDir = Utils.createTempDir()
     // Set s3a implementation to local file system for testing.
     hadoopConf.set("fs.s3a.impl", "org.apache.spark.deploy.TestFileSystem")
     // Disable file system impl cache to make sure the test file system is picked up.
     hadoopConf.set("fs.s3a.impl.disable.cache", "true")
     intercept[FileNotFoundException] {
-      SparkSubmit.downloadFile("s3a:/no/such/file", hadoopConf)
+      SparkSubmit.downloadFile("s3a:/no/such/file", tmpDir, mutable.Map.empty, hadoopConf)
     }
   }
 
   test("downloadFile does not download local file") {
     // empty path is considered as local file.
-    assert(SparkSubmit.downloadFile("", new Configuration()) === "")
-    assert(SparkSubmit.downloadFile("/local/file", new Configuration()) === "/local/file")
+    val tmpDir = Files.createTempDirectory("tmp").toFile
+    assert(SparkSubmit.downloadFile("", tmpDir, mutable.Map.empty, new Configuration()) === "")
+    assert(SparkSubmit.downloadFile("/local/file", tmpDir, mutable.Map.empty,
+      new Configuration()) === "/local/file")
   }
 
   test("download one file to local") {
@@ -779,12 +825,14 @@ class SparkSubmitSuite
     val content = "hello, world"
     FileUtils.write(jarFile, content)
     val hadoopConf = new Configuration()
+    val tmpDir = Files.createTempDirectory("tmp").toFile
     // Set s3a implementation to local file system for testing.
     hadoopConf.set("fs.s3a.impl", "org.apache.spark.deploy.TestFileSystem")
     // Disable file system impl cache to make sure the test file system is picked up.
     hadoopConf.set("fs.s3a.impl.disable.cache", "true")
     val sourcePath = s"s3a://${jarFile.getAbsolutePath}"
-    val outputPath = SparkSubmit.downloadFile(sourcePath, hadoopConf)
+    val outputPath =
+      SparkSubmit.downloadFile(sourcePath, tmpDir, mutable.Map.empty, hadoopConf)
     checkDownloadedFile(sourcePath, outputPath)
     deleteTempOutputFile(outputPath)
   }
@@ -795,12 +843,14 @@ class SparkSubmitSuite
     val content = "hello, world"
     FileUtils.write(jarFile, content)
     val hadoopConf = new Configuration()
+    val tmpDir = Files.createTempDirectory("tmp").toFile
     // Set s3a implementation to local file system for testing.
     hadoopConf.set("fs.s3a.impl", "org.apache.spark.deploy.TestFileSystem")
     // Disable file system impl cache to make sure the test file system is picked up.
     hadoopConf.set("fs.s3a.impl.disable.cache", "true")
     val sourcePaths = Seq("/local/file", s"s3a://${jarFile.getAbsolutePath}")
-    val outputPaths = SparkSubmit.downloadFileList(sourcePaths.mkString(","), hadoopConf).split(",")
+    val outputPaths = SparkSubmit.downloadFileList(
+      sourcePaths.mkString(","), tmpDir, mutable.Map.empty, hadoopConf).split(",")
 
     assert(outputPaths.length === sourcePaths.length)
     sourcePaths.zip(outputPaths).foreach { case (sourcePath, outputPath) =>
diff --git a/docs/configuration.md b/docs/configuration.md
index c785a664c6..7dc23e441a 100644
--- a/docs/configuration.md
+++ b/docs/configuration.md
@@ -422,21 +422,21 @@ Apart from these, the following properties are also available, and may be useful
   <td><code>spark.files</code></td>
   <td></td>
   <td>
-    Comma-separated list of files to be placed in the working directory of each executor.
+    Comma-separated list of files to be placed in the working directory of each executor. Globs are allowed.
   </td>
 </tr>
 <tr>
   <td><code>spark.submit.pyFiles</code></td>
   <td></td>
   <td>
-    Comma-separated list of .zip, .egg, or .py files to place on the PYTHONPATH for Python apps.
+    Comma-separated list of .zip, .egg, or .py files to place on the PYTHONPATH for Python apps. Globs are allowed.
   </td>
 </tr>
 <tr>
   <td><code>spark.jars</code></td>
   <td></td>
   <td>
-    Comma-separated list of local jars to include on the driver and executor classpaths.
+    Comma-separated list of jars to include on the driver and executor classpaths. Globs are allowed.
   </td>
 </tr>
 <tr>
-- 
GitLab