Commit e33aaa2a authored by gatorsmile, committed by Wenchen Fan

[SPARK-19397][SQL] Make option names of LIBSVM and TEXT case insensitive

### What changes were proposed in this pull request?
Prior to Spark 2.1, option names were case sensitive for all formats. Since Spark 2.1, option key names have been case insensitive for every format except `Text` and `LibSVM`. This PR fixes these issues.

Also, add a check that the `vectorType` option supplied for `LibSVM` has a legal value.
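
For illustration, a minimal usage sketch of the new behavior (the input path is a placeholder): option names for the `libsvm` source now match regardless of case, and an unsupported `vectorType` value fails with a clear error.

```scala
// Option keys are resolved case-insensitively after this change.
val df = spark.read
  .format("libsvm")
  .option("NuMfEaTuReS", "100")           // resolves to "numFeatures"
  .load("data/sample_libsvm_data.txt")    // placeholder path

// An unsupported vector type now fails fast, e.g.
//   spark.read.format("libsvm").option("vectorType", "sparser").load(...)
// throws: Invalid value `sparser` for parameter `vectorType`.
```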

### How was this patch tested?
Added test cases

Author: gatorsmile <gatorsmile@gmail.com>

Closes #16737 from gatorsmile/libSVMTextOptions.
parent 8df44440
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.ml.source.libsvm

import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap

/**
 * Options for the LibSVM data source.
 */
private[libsvm] class LibSVMOptions(@transient private val parameters: CaseInsensitiveMap)
  extends Serializable {

  import LibSVMOptions._

  def this(parameters: Map[String, String]) = this(new CaseInsensitiveMap(parameters))

  /**
   * Number of features. If unspecified or nonpositive, the number of features will be determined
   * automatically at the cost of one additional pass.
   */
  val numFeatures = parameters.get(NUM_FEATURES).map(_.toInt).filter(_ > 0)

  val isSparse = parameters.getOrElse(VECTOR_TYPE, SPARSE_VECTOR_TYPE) match {
    case SPARSE_VECTOR_TYPE => true
    case DENSE_VECTOR_TYPE => false
    case o => throw new IllegalArgumentException(s"Invalid value `$o` for parameter " +
      s"`$VECTOR_TYPE`. Expected types are `sparse` and `dense`.")
  }
}

private[libsvm] object LibSVMOptions {
  val NUM_FEATURES = "numFeatures"
  val VECTOR_TYPE = "vectorType"
  val DENSE_VECTOR_TYPE = "dense"
  val SPARSE_VECTOR_TYPE = "sparse"
}
@@ -77,7 +77,7 @@ private[libsvm] class LibSVMFileFormat extends TextBasedFileFormat with DataSour
dataSchema.size != 2 ||
!dataSchema(0).dataType.sameType(DataTypes.DoubleType) ||
!dataSchema(1).dataType.sameType(new VectorUDT()) ||
- !(dataSchema(1).metadata.getLong("numFeatures").toInt > 0)
+ !(dataSchema(1).metadata.getLong(LibSVMOptions.NUM_FEATURES).toInt > 0)
) {
throw new IOException(s"Illegal schema for libsvm data, schema=$dataSchema")
}
@@ -87,7 +87,8 @@ private[libsvm] class LibSVMFileFormat extends TextBasedFileFormat with DataSour
sparkSession: SparkSession,
options: Map[String, String],
files: Seq[FileStatus]): Option[StructType] = {
- val numFeatures: Int = options.get("numFeatures").map(_.toInt).filter(_ > 0).getOrElse {
+ val libSVMOptions = new LibSVMOptions(options)
+ val numFeatures: Int = libSVMOptions.numFeatures.getOrElse {
// Infers number of features if the user doesn't specify (a valid) one.
val dataFiles = files.filterNot(_.getPath.getName startsWith "_")
val path = if (dataFiles.length == 1) {
@@ -104,7 +105,7 @@ private[libsvm] class LibSVMFileFormat extends TextBasedFileFormat with DataSour
}
val featuresMetadata = new MetadataBuilder()
.putLong("numFeatures", numFeatures)
.putLong(LibSVMOptions.NUM_FEATURES, numFeatures)
.build()
Some(
@@ -142,10 +143,11 @@ private[libsvm] class LibSVMFileFormat extends TextBasedFileFormat with DataSour
options: Map[String, String],
hadoopConf: Configuration): (PartitionedFile) => Iterator[InternalRow] = {
verifySchema(dataSchema)
- val numFeatures = dataSchema("features").metadata.getLong("numFeatures").toInt
+ val numFeatures = dataSchema("features").metadata.getLong(LibSVMOptions.NUM_FEATURES).toInt
assert(numFeatures > 0)
- val sparse = options.getOrElse("vectorType", "sparse") == "sparse"
+ val libSVMOptions = new LibSVMOptions(options)
+ val isSparse = libSVMOptions.isSparse
val broadcastedHadoopConf =
sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf))
@@ -173,7 +175,7 @@ private[libsvm] class LibSVMFileFormat extends TextBasedFileFormat with DataSour
val requiredColumns = GenerateUnsafeProjection.generate(requiredOutput, fullOutput)
points.map { pt =>
- val features = if (sparse) pt.features.toSparse else pt.features.toDense
+ val features = if (isSparse) pt.features.toSparse else pt.features.toDense
requiredColumns(converter.toRow(Row(pt.label, features)))
}
}
@@ -77,6 +77,14 @@ class LibSVMRelationSuite extends SparkFunSuite with MLlibTestSparkContext {
assert(v == Vectors.dense(1.0, 0.0, 2.0, 0.0, 3.0, 0.0))
}
test("illegal vector types") {
val e = intercept[IllegalArgumentException] {
spark.read.format("libsvm").options(Map("VectorType" -> "sparser")).load(path)
}.getMessage
assert(e.contains("Invalid value `sparser` for parameter `vectorType`. Expected " +
"types are `sparse` and `dense`."))
}
test("select a vector with specifying the longer dimension") {
val df = spark.read.option("numFeatures", "100").format("libsvm")
.load(path)
@@ -85,6 +93,12 @@ class LibSVMRelationSuite extends SparkFunSuite with MLlibTestSparkContext {
assert(v == Vectors.sparse(100, Seq((0, 1.0), (2, 2.0), (4, 3.0))))
}
test("case insensitive option") {
val df = spark.read.option("NuMfEaTuReS", "100").format("libsvm").load(path)
assert(df.first().getAs[SparseVector](1) ==
Vectors.sparse(100, Seq((0, 1.0), (2, 2.0), (4, 3.0))))
}
test("write libsvm data and read it again") {
val df = spark.read.format("libsvm").load(path)
val tempDir2 = new File(tempDir, "read_write_test")
@@ -65,9 +65,10 @@ class TextFileFormat extends TextBasedFileFormat with DataSourceRegister {
dataSchema: StructType): OutputWriterFactory = {
verifySchema(dataSchema)
+ val textOptions = new TextOptions(options)
val conf = job.getConfiguration
- val compressionCodec = options.get("compression").map(CompressionCodecs.getCodecClassName)
- compressionCodec.foreach { codec =>
+ textOptions.compressionCodec.foreach { codec =>
CompressionCodecs.setCodecConfiguration(conf, codec)
}
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.execution.datasources.text

import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, CompressionCodecs}

/**
 * Options for the Text data source.
 */
private[text] class TextOptions(@transient private val parameters: CaseInsensitiveMap)
  extends Serializable {

  import TextOptions._

  def this(parameters: Map[String, String]) = this(new CaseInsensitiveMap(parameters))

  /**
   * Compression codec to use.
   */
  val compressionCodec = parameters.get(COMPRESSION).map(CompressionCodecs.getCodecClassName)
}

private[text] object TextOptions {
  val COMPRESSION = "compression"
}
@@ -115,8 +115,7 @@ class TextSuite extends QueryTest with SharedSQLContext {
)
withTempDir { dir =>
val testDf = spark.read.text(testFile)
- val tempDir = Utils.createTempDir()
- val tempDirPath = tempDir.getAbsolutePath
+ val tempDirPath = dir.getAbsolutePath
testDf.write.option("compression", "none")
.options(extraOptions).mode(SaveMode.Overwrite).text(tempDirPath)
val compressedFiles = new File(tempDirPath).listFiles()
@@ -125,6 +124,25 @@ class TextSuite extends QueryTest with SharedSQLContext {
}
}
test("case insensitive option") {
val extraOptions = Map[String, String](
"mApReDuCe.output.fileoutputformat.compress" -> "true",
"mApReDuCe.output.fileoutputformat.compress.type" -> CompressionType.BLOCK.toString,
"mApReDuCe.map.output.compress" -> "true",
"mApReDuCe.output.fileoutputformat.compress.codec" -> classOf[GzipCodec].getName,
"mApReDuCe.map.output.compress.codec" -> classOf[GzipCodec].getName
)
withTempDir { dir =>
val testDf = spark.read.text(testFile)
val tempDirPath = dir.getAbsolutePath
testDf.write.option("CoMpReSsIoN", "none")
.options(extraOptions).mode(SaveMode.Overwrite).text(tempDirPath)
val compressedFiles = new File(tempDirPath).listFiles()
assert(compressedFiles.exists(!_.getName.endsWith(".txt.gz")))
verifyFrame(spark.read.options(extraOptions).text(tempDirPath))
}
}
test("SPARK-14343: select partitioning column") {
withTempPath { dir =>
val path = dir.getCanonicalPath
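
As a user-facing counterpart to the test above, a short write-side sketch (the output path is a placeholder) showing that the text source's `compression` option is now matched case-insensitively as well:

```scala
// "CoMpReSsIoN" resolves to the "compression" option after this change.
spark.range(10).selectExpr("cast(id as string) as value")
  .write
  .mode("overwrite")
  .option("CoMpReSsIoN", "gzip")
  .text("/tmp/text_gzip_out")   // placeholder output path
```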