From d451b7f43d559aa1efd7ac3d1cbec5249f3a7a24 Mon Sep 17 00:00:00 2001
From: hyukjinkwon <gurwls223@gmail.com>
Date: Fri, 7 Jul 2017 12:24:03 +0800
Subject: [PATCH] [SPARK-21326][SPARK-21066][ML] Use TextFileFormat in
 LibSVMFileFormat and allow multiple input paths for determining numFeatures

## What changes were proposed in this pull request?

This is related to [SPARK-19918](https://issues.apache.org/jira/browse/SPARK-19918) and [SPARK-18362](https://issues.apache.org/jira/browse/SPARK-18362).

This PR proposes to use `TextFileFormat` and to allow multiple input paths (with a warning) when determining the number of features in the LibSVM data source via an extra scan.

There are three points here:

- The main advantage of this change is removing the file-listing bottleneck on the driver side.
- Another advantage comes from using `FileScanRDD`. For example, the `spark.sql.files.ignoreCorruptFiles` option should now apply when determining the number of features.
- It unifies the schema inference code path across text-based data sources. This is also a preparation for [SPARK-21289](https://issues.apache.org/jira/browse/SPARK-21289).

## How was this patch tested?

Unit tests in `LibSVMRelationSuite`.

Closes #18288

Author: hyukjinkwon <gurwls223@gmail.com>

Closes #18556 from HyukjinKwon/libsvm-schema.
---
 .../ml/source/libsvm/LibSVMRelation.scala     | 26 +++++++++----------
 .../org/apache/spark/mllib/util/MLUtils.scala | 25 ++++++++++++++++--
 .../source/libsvm/LibSVMRelationSuite.scala   | 17 +++++++++---
 3 files changed, 49 insertions(+), 19 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
index f68847a664..dec118330a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
@@ -23,6 +23,7 @@ import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.fs.{FileStatus, Path}
 import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext}
 
+import org.apache.spark.internal.Logging
 import org.apache.spark.TaskContext
 import org.apache.spark.ml.feature.LabeledPoint
 import org.apache.spark.ml.linalg.{Vectors, VectorUDT}
@@ -66,7 +67,10 @@ private[libsvm] class LibSVMOutputWriter(
 
 /** @see [[LibSVMDataSource]] for public documentation. */
 // If this is moved or renamed, please update DataSource's backwardCompatibilityMap.
-private[libsvm] class LibSVMFileFormat extends TextBasedFileFormat with DataSourceRegister {
+private[libsvm] class LibSVMFileFormat
+  extends TextBasedFileFormat
+  with DataSourceRegister
+  with Logging {
 
   override def shortName(): String = "libsvm"
 
@@ -89,18 +93,14 @@ private[libsvm] class LibSVMFileFormat extends TextBasedFileFormat with DataSour
       files: Seq[FileStatus]): Option[StructType] = {
     val libSVMOptions = new LibSVMOptions(options)
     val numFeatures: Int = libSVMOptions.numFeatures.getOrElse {
-      // Infers number of features if the user doesn't specify (a valid) one.
-      val dataFiles = files.filterNot(_.getPath.getName startsWith "_")
-      val path = if (dataFiles.length == 1) {
-        dataFiles.head.getPath.toUri.toString
-      } else if (dataFiles.isEmpty) {
-        throw new IOException("No input path specified for libsvm data")
-      } else {
-        throw new IOException("Multiple input paths are not supported for libsvm data.")
-      }
-
-      val sc = sparkSession.sparkContext
-      val parsed = MLUtils.parseLibSVMFile(sc, path, sc.defaultParallelism)
+      require(files.nonEmpty, "No input path specified for libsvm data")
+      logWarning(
+        "'numFeatures' option not specified, determining the number of features by going " +
+          "through the input. If you know the number in advance, please specify it via " +
+          "'numFeatures' option to avoid the extra scan.")
+
+      val paths = files.map(_.getPath.toUri.toString)
+      val parsed = MLUtils.parseLibSVMFile(sparkSession, paths)
 
       MLUtils.computeNumFeatures(parsed)
     }
 
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
index 4fdad05973..14af8b5c73 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
@@ -28,8 +28,10 @@ import org.apache.spark.mllib.linalg._
 import org.apache.spark.mllib.linalg.BLAS.dot
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.rdd.{PartitionwiseSampledRDD, RDD}
-import org.apache.spark.sql.{DataFrame, Dataset}
-import org.apache.spark.sql.functions.{col, udf}
+import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}
+import org.apache.spark.sql.execution.datasources.DataSource
+import org.apache.spark.sql.execution.datasources.text.TextFileFormat
+import org.apache.spark.sql.functions._
 import org.apache.spark.storage.StorageLevel
 import org.apache.spark.util.random.BernoulliCellSampler
 
@@ -102,6 +104,25 @@ object MLUtils extends Logging {
       .map(parseLibSVMRecord)
   }
 
+  private[spark] def parseLibSVMFile(
+      sparkSession: SparkSession, paths: Seq[String]): RDD[(Double, Array[Int], Array[Double])] = {
+    val lines = sparkSession.baseRelationToDataFrame(
+      DataSource.apply(
+        sparkSession,
+        paths = paths,
+        className = classOf[TextFileFormat].getName
+      ).resolveRelation(checkFilesExist = false))
+      .select("value")
+
+    import lines.sqlContext.implicits._
+
+    lines.select(trim($"value").as("line"))
+      .filter(not((length($"line") === 0).or($"line".startsWith("#"))))
+      .as[String]
+      .rdd
+      .map(MLUtils.parseLibSVMRecord)
+  }
+
   private[spark] def parseLibSVMRecord(line: String): (Double, Array[Int], Array[Double]) = {
     val items = line.split(' ')
     val label = items.head.toDouble
diff --git a/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala
index e164d279f3..a67e49d54e 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala
@@ -35,15 +35,22 @@ class LibSVMRelationSuite extends SparkFunSuite with MLlibTestSparkContext {
 
   override def beforeAll(): Unit = {
     super.beforeAll()
-    val lines =
+    val lines0 =
       """
         |1 1:1.0 3:2.0 5:3.0
         |0
+      """.stripMargin
+    val lines1 =
+      """
         |0 2:4.0 4:5.0 6:6.0
       """.stripMargin
     val dir = Utils.createDirectory(tempDir.getCanonicalPath, "data")
-    val file = new File(dir, "part-00000")
-    Files.write(lines, file, StandardCharsets.UTF_8)
+    val succ = new File(dir, "_SUCCESS")
+    val file0 = new File(dir, "part-00000")
+    val file1 = new File(dir, "part-00001")
+    Files.write("", succ, StandardCharsets.UTF_8)
+    Files.write(lines0, file0, StandardCharsets.UTF_8)
+    Files.write(lines1, file1, StandardCharsets.UTF_8)
     path = dir.toURI.toString
   }
 
@@ -145,7 +152,9 @@ class LibSVMRelationSuite extends SparkFunSuite with MLlibTestSparkContext {
 
   test("create libsvmTable table without schema and path") {
     try {
-      val e = intercept[IOException](spark.sql("CREATE TABLE libsvmTable USING libsvm"))
+      val e = intercept[IllegalArgumentException] {
+        spark.sql("CREATE TABLE libsvmTable USING libsvm")
+      }
       assert(e.getMessage.contains("No input path specified for libsvm data"))
     } finally {
       spark.sql("DROP TABLE IF EXISTS libsvmTable")
--
GitLab
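
For context (this note and sketch are not part of the patch), here is a minimal usage sketch of the user-facing behavior the change affects: reading LibSVM data with and without the `numFeatures` option. The paths and the feature count below are placeholder values, and `LibSVMReadExample` is a hypothetical object name.

```scala
import org.apache.spark.sql.SparkSession

object LibSVMReadExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("LibSVMReadExample").getOrCreate()

    // Without 'numFeatures': the data source infers the number of features with an
    // extra scan over the input. After this patch, multiple input paths are accepted
    // here and a warning is logged about the extra scan.
    val inferred = spark.read.format("libsvm")
      .load("data/part-00000", "data/part-00001")

    // With 'numFeatures' specified explicitly, the extra scan is skipped.
    val explicit = spark.read.format("libsvm")
      .option("numFeatures", "6")
      .load("data/part-00000", "data/part-00001")

    // Both DataFrames have the usual "label" and "features" columns.
    inferred.printSchema()
    explicit.printSchema()

    spark.stop()
  }
}
```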