Commit 1b070637 authored by Cheng Lian, committed by Xiangrui Meng

[SPARK-14295][SPARK-14274][SQL] Implements buildReader() for LibSVM

## What changes were proposed in this pull request?

This PR implements `FileFormat.buildReader()` for the LibSVM data source. Besides that, a new interface method `prepareRead()` is added to `FileFormat`:

```scala
  def prepareRead(
      sqlContext: SQLContext,
      options: Map[String, String],
      files: Seq[FileStatus]): Map[String, String] = options
```

After migrating from `buildInternalScan()` to `buildReader()`, we lost the opportunity to collect global information up front, since `buildReader()` works on a per-partition basis. For example, LibSVM needs to infer the total number of features when the `numFeatures` data source option is not set. Any global information collected this way should be passed back through the data source options map. By default, this method simply returns the original options untouched.
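For illustration, here is a sketch of how a `FileFormat` implementation might override `prepareRead()` to stash a globally collected value into the options map so that `buildReader()` can pick it up later. The format and the `maxRecordLength` option key are hypothetical; the pattern mirrors what the LibSVM source does with `numFeatures` below:

```scala
  // Hypothetical override (the "maxRecordLength" option is made up): run one
  // global pass over the input files and record the result as a data source
  // option, so the per-partition buildReader() never has to re-scan the input.
  override def prepareRead(
      sqlContext: SQLContext,
      options: Map[String, String],
      files: Seq[FileStatus]): Map[String, String] = {
    val maxRecordLength = options.getOrElse("maxRecordLength", {
      val paths = files.map(_.getPath.toString).mkString(",")
      sqlContext.sparkContext.textFile(paths).map(_.length).fold(0)(math.max).toString
    })
    options + ("maxRecordLength" -> maxRecordLength)
  }
```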

An alternative approach is to absorb `inferSchema()` into `prepareRead()`, since schema inference is also a form of global information gathering. This approach wasn't chosen, though, because schema inference is optional, while `prepareRead()` must be called whenever a `HadoopFsRelation`-based data source relation is instantiated.

One unaddressed problem is that, when `numFeatures` is absent, the input data is now scanned twice. The `buildInternalScan()` code path didn't need the second scan because it cached the raw parsed RDD in memory before computing the total number of features. With `FileScanRDD`, however, the raw parsed RDD is created differently (e.g. with different partitioning) from the final RDD, so it cannot simply be reused.

## How was this patch tested?

Tested using existing test suites.

Author: Cheng Lian <lian@databricks.com>

Closes #12088 from liancheng/spark-14295-libsvm-build-reader.
parent 96941b12
@@ -19,6 +19,7 @@ package org.apache.spark.ml.source.libsvm
 import java.io.IOException

+import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.fs.{FileStatus, Path}
 import org.apache.hadoop.io.{NullWritable, Text}
 import org.apache.hadoop.mapreduce.{Job, RecordWriter, TaskAttemptContext}
@@ -26,12 +27,16 @@ import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat
 import org.apache.spark.annotation.Since
 import org.apache.spark.broadcast.Broadcast
-import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
+import org.apache.spark.mllib.linalg.{Vector, Vectors, VectorUDT}
+import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.util.MLUtils
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, DataFrameReader, Row, SQLContext}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.encoders.RowEncoder
+import org.apache.spark.sql.catalyst.expressions.{AttributeReference, JoinedRow}
+import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
+import org.apache.spark.sql.execution.datasources.{CaseInsensitiveMap, HadoopFileLinesReader, PartitionedFile}
 import org.apache.spark.sql.sources._
 import org.apache.spark.sql.types._
 import org.apache.spark.util.SerializableConfiguration
@@ -110,13 +115,16 @@ class DefaultSource extends FileFormat with DataSourceRegister {
   @Since("1.6.0")
   override def shortName(): String = "libsvm"

+  override def toString: String = "LibSVM"
+
   private def verifySchema(dataSchema: StructType): Unit = {
     if (dataSchema.size != 2 ||
       (!dataSchema(0).dataType.sameType(DataTypes.DoubleType)
         || !dataSchema(1).dataType.sameType(new VectorUDT()))) {
-      throw new IOException(s"Illegal schema for libsvm data, schema=${dataSchema}")
+      throw new IOException(s"Illegal schema for libsvm data, schema=$dataSchema")
     }
   }

   override def inferSchema(
       sqlContext: SQLContext,
       options: Map[String, String],
@@ -127,6 +135,32 @@ class DefaultSource extends FileFormat with DataSourceRegister {
         StructField("features", new VectorUDT(), nullable = false) :: Nil))
   }

+  override def prepareRead(
+      sqlContext: SQLContext,
+      options: Map[String, String],
+      files: Seq[FileStatus]): Map[String, String] = {
+    def computeNumFeatures(): Int = {
+      val dataFiles = files.filterNot(_.getPath.getName startsWith "_")
+      val path = if (dataFiles.length == 1) {
+        dataFiles.head.getPath.toUri.toString
+      } else if (dataFiles.isEmpty) {
+        throw new IOException("No input path specified for libsvm data")
+      } else {
+        throw new IOException("Multiple input paths are not supported for libsvm data.")
+      }
+
+      val sc = sqlContext.sparkContext
+      val parsed = MLUtils.parseLibSVMFile(sc, path, sc.defaultParallelism)
+      MLUtils.computeNumFeatures(parsed)
+    }
+
+    val numFeatures = options.get("numFeatures").filter(_.toInt > 0).getOrElse {
+      computeNumFeatures()
+    }
+
+    new CaseInsensitiveMap(options + ("numFeatures" -> numFeatures.toString))
+  }
+
   override def prepareWrite(
       sqlContext: SQLContext,
       job: Job,
@@ -158,7 +192,7 @@ class DefaultSource extends FileFormat with DataSourceRegister {
     verifySchema(dataSchema)
     val dataFiles = inputFiles.filterNot(_.getPath.getName startsWith "_")

-    val path = if (dataFiles.length == 1) dataFiles(0).getPath.toUri.toString
+    val path = if (dataFiles.length == 1) dataFiles.head.getPath.toUri.toString
     else if (dataFiles.isEmpty) throw new IOException("No input path specified for libsvm data")
     else throw new IOException("Multiple input paths are not supported for libsvm data.")
@@ -176,4 +210,51 @@ class DefaultSource extends FileFormat with DataSourceRegister {
       externalRows.map(converter.toRow)
     }
   }
+
+  override def buildReader(
+      sqlContext: SQLContext,
+      dataSchema: StructType,
+      partitionSchema: StructType,
+      requiredSchema: StructType,
+      filters: Seq[Filter],
+      options: Map[String, String]): (PartitionedFile) => Iterator[InternalRow] = {
+    val numFeatures = options("numFeatures").toInt
+    assert(numFeatures > 0)
+
+    val sparse = options.getOrElse("vectorType", "sparse") == "sparse"
+
+    val broadcastedConf = sqlContext.sparkContext.broadcast(
+      new SerializableConfiguration(new Configuration(sqlContext.sparkContext.hadoopConfiguration))
+    )
+
+    (file: PartitionedFile) => {
+      val points =
+        new HadoopFileLinesReader(file, broadcastedConf.value.value)
+          .map(_.toString.trim)
+          .filterNot(line => line.isEmpty || line.startsWith("#"))
+          .map { line =>
+            val (label, indices, values) = MLUtils.parseLibSVMRecord(line)
+            LabeledPoint(label, Vectors.sparse(numFeatures, indices, values))
+          }

+      val converter = RowEncoder(requiredSchema)
+
+      val unsafeRowIterator = points.map { pt =>
+        val features = if (sparse) pt.features.toSparse else pt.features.toDense
+        converter.toRow(Row(pt.label, features))
+      }
+
+      def toAttribute(f: StructField): AttributeReference =
+        AttributeReference(f.name, f.dataType, f.nullable, f.metadata)()
+
+      // Appends partition values
+      val fullOutput = (requiredSchema ++ partitionSchema).map(toAttribute)
+      val joinedRow = new JoinedRow()
+      val appendPartitionColumns = GenerateUnsafeProjection.generate(fullOutput, fullOutput)
+
+      unsafeRowIterator.map { dataRow =>
+        appendPartitionColumns(joinedRow(dataRow, file.partitionValues))
+      }
+    }
+  }
 }
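For reference, a minimal usage sketch of the LibSVM source after this change (the input path is a placeholder). Passing `numFeatures` explicitly skips the inference pass that `prepareRead()` would otherwise run:

```scala
// Explicit numFeatures: prepareRead() keeps the option as-is, no extra scan.
val withHint = sqlContext.read
  .format("libsvm")
  .option("numFeatures", "780")
  .load("data/mllib/sample_libsvm_data.txt")

// No numFeatures: prepareRead() scans the input once to infer it.
// vectorType controls whether buildReader() emits sparse (default) or dense vectors.
val inferred = sqlContext.read
  .format("libsvm")
  .option("vectorType", "dense")
  .load("data/mllib/sample_libsvm_data.txt")
```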
@@ -67,42 +67,14 @@ object MLUtils {
       path: String,
       numFeatures: Int,
       minPartitions: Int): RDD[LabeledPoint] = {
-    val parsed = sc.textFile(path, minPartitions)
-      .map(_.trim)
-      .filter(line => !(line.isEmpty || line.startsWith("#")))
-      .map { line =>
-        val items = line.split(' ')
-        val label = items.head.toDouble
-        val (indices, values) = items.tail.filter(_.nonEmpty).map { item =>
-          val indexAndValue = item.split(':')
-          val index = indexAndValue(0).toInt - 1 // Convert 1-based indices to 0-based.
-          val value = indexAndValue(1).toDouble
-          (index, value)
-        }.unzip
-        // check if indices are one-based and in ascending order
-        var previous = -1
-        var i = 0
-        val indicesLength = indices.length
-        while (i < indicesLength) {
-          val current = indices(i)
-          require(current > previous, s"indices should be one-based and in ascending order;"
-            + " found current=$current, previous=$previous; line=\"$line\"")
-          previous = current
-          i += 1
-        }
-        (label, indices.toArray, values.toArray)
-      }
+    val parsed = parseLibSVMFile(sc, path, minPartitions)

     // Determine number of features.
     val d = if (numFeatures > 0) {
       numFeatures
     } else {
       parsed.persist(StorageLevel.MEMORY_ONLY)
-      parsed.map { case (label, indices, values) =>
-        indices.lastOption.getOrElse(0)
-      }.reduce(math.max) + 1
+      computeNumFeatures(parsed)
     }

     parsed.map { case (label, indices, values) =>
@@ -110,6 +82,47 @@ object MLUtils {
     }
   }

+  private[spark] def computeNumFeatures(rdd: RDD[(Double, Array[Int], Array[Double])]): Int = {
+    rdd.map { case (label, indices, values) =>
+      indices.lastOption.getOrElse(0)
+    }.reduce(math.max) + 1
+  }
+
+  private[spark] def parseLibSVMFile(
+      sc: SparkContext,
+      path: String,
+      minPartitions: Int): RDD[(Double, Array[Int], Array[Double])] = {
+    sc.textFile(path, minPartitions)
+      .map(_.trim)
+      .filter(line => !(line.isEmpty || line.startsWith("#")))
+      .map(parseLibSVMRecord)
+  }
+
+  private[spark] def parseLibSVMRecord(line: String): (Double, Array[Int], Array[Double]) = {
+    val items = line.split(' ')
+    val label = items.head.toDouble
+    val (indices, values) = items.tail.filter(_.nonEmpty).map { item =>
+      val indexAndValue = item.split(':')
+      val index = indexAndValue(0).toInt - 1 // Convert 1-based indices to 0-based.
+      val value = indexAndValue(1).toDouble
+      (index, value)
+    }.unzip
+
+    // Check that indices are one-based and in ascending order.
+    var previous = -1
+    var i = 0
+    val indicesLength = indices.length
+    while (i < indicesLength) {
+      val current = indices(i)
+      require(current > previous, s"indices should be one-based and in ascending order;" +
+        s" found current=$current, previous=$previous; line=\"$line\"")
+      previous = current
+      i += 1
+    }
+    (label, indices, values)
+  }
+
   /**
    * Loads labeled data in the LIBSVM format into an RDD[LabeledPoint], with the default number of
    * partitions.
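As a worked example of the record format parsed above (note that `parseLibSVMRecord` is `private[spark]`, so it is only callable from within Spark's own packages):

```scala
// A LibSVM record is "<label> <index1>:<value1> <index2>:<value2> ...",
// with 1-based, strictly increasing indices.
val (label, indices, values) = MLUtils.parseLibSVMRecord("1.0 1:0.5 3:2.0")
// label   == 1.0
// indices == Array(0, 2)     // 1-based indices converted to 0-based
// values  == Array(0.5, 2.0)
```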
@@ -299,6 +299,9 @@ case class DataSource(
               "It must be specified manually")
         }

+        val enrichedOptions =
+          format.prepareRead(sqlContext, caseInsensitiveOptions, fileCatalog.allFiles())
+
         HadoopFsRelation(
           sqlContext,
           fileCatalog,
@@ -306,7 +309,7 @@ case class DataSource(
           dataSchema = dataSchema.asNullable,
           bucketSpec = bucketSpec,
           format,
-          options)
+          enrichedOptions)

       case _ =>
         throw new AnalysisException(
@@ -59,6 +59,7 @@ private[sql] object FileSourceStrategy extends Strategy with Logging {
       if (files.fileFormat.toString == "TestFileFormat" ||
          files.fileFormat.isInstanceOf[parquet.DefaultSource] ||
          files.fileFormat.toString == "ORC" ||
+         files.fileFormat.toString == "LibSVM" ||
          files.fileFormat.isInstanceOf[csv.DefaultSource] ||
          files.fileFormat.isInstanceOf[text.DefaultSource] ||
          files.fileFormat.isInstanceOf[json.DefaultSource]) &&
@@ -438,6 +438,15 @@ trait FileFormat {
       options: Map[String, String],
       files: Seq[FileStatus]): Option[StructType]

+  /**
+   * Prepares a read job and returns a potentially updated data source option [[Map]]. This method
+   * can be useful for collecting necessary global information for scanning input data.
+   */
+  def prepareRead(
+      sqlContext: SQLContext,
+      options: Map[String, String],
+      files: Seq[FileStatus]): Map[String, String] = options
+
   /**
    * Prepares a write job and returns an [[OutputWriterFactory]]. Client side job preparation can
    * be put here. For example, user defined output committer can be configured here