diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index a787d5a9a94380a7d5365930cbf328c7c614b8e0..1830839aeebb70a7eb3d753d970a6ebc5c11786f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -31,7 +31,7 @@ import org.apache.spark.sql.execution.LogicalRDD import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.datasources.jdbc._ -import org.apache.spark.sql.execution.datasources.json.InferSchema +import org.apache.spark.sql.execution.datasources.json.JsonInferSchema import org.apache.spark.sql.types.StructType /** @@ -334,7 +334,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { parsedOptions.columnNameOfCorruptRecord .getOrElse(sparkSession.sessionState.conf.columnNameOfCorruptRecord) val schema = userSpecifiedSchema.getOrElse { - InferSchema.infer( + JsonInferSchema.infer( jsonRDD, columnNameOfCorruptRecord, parsedOptions) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala index 38970160d5fb3c1e5a775ecd7cdfd9c97f62dcd2..1d2bf07047a23f3230021bd7a60c6b668674111a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.execution.datasources.csv import java.nio.charset.{Charset, StandardCharsets} -import com.univocity.parsers.csv.CsvParser import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.hadoop.io.{LongWritable, Text} @@ -28,13 +27,11 @@ import org.apache.hadoop.mapreduce._ import org.apache.spark.TaskContext import org.apache.spark.internal.Logging -import org.apache.spark.rdd.RDD import org.apache.spark.sql.{Dataset, Encoders, SparkSession} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.util.CompressionCodecs import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.datasources.text.TextFileFormat -import org.apache.spark.sql.functions.{length, trim} import org.apache.spark.sql.sources._ import org.apache.spark.sql.types._ import org.apache.spark.util.SerializableConfiguration @@ -60,64 +57,9 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister { val csvOptions = new CSVOptions(options) val paths = files.map(_.getPath.toString) - val lines: Dataset[String] = readText(sparkSession, csvOptions, paths) - val firstLine: String = findFirstLine(csvOptions, lines) - val firstRow = new CsvParser(csvOptions.asParserSettings).parseLine(firstLine) + val lines: Dataset[String] = createBaseDataset(sparkSession, csvOptions, paths) val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis - val header = makeSafeHeader(firstRow, csvOptions, caseSensitive) - - val parsedRdd: RDD[Array[String]] = CSVRelation.univocityTokenizer( - lines, - firstLine = if (csvOptions.headerFlag) firstLine else null, - params = csvOptions) - val schema = if (csvOptions.inferSchemaFlag) { - CSVInferSchema.infer(parsedRdd, header, csvOptions) - } else { - // By default fields are assumed to be StringType 
- val schemaFields = header.map { fieldName => - StructField(fieldName, StringType, nullable = true) - } - StructType(schemaFields) - } - Some(schema) - } - - /** - * Generates a header from the given row which is null-safe and duplicate-safe. - */ - private def makeSafeHeader( - row: Array[String], - options: CSVOptions, - caseSensitive: Boolean): Array[String] = { - if (options.headerFlag) { - val duplicates = { - val headerNames = row.filter(_ != null) - .map(name => if (caseSensitive) name else name.toLowerCase) - headerNames.diff(headerNames.distinct).distinct - } - - row.zipWithIndex.map { case (value, index) => - if (value == null || value.isEmpty || value == options.nullValue) { - // When there are empty strings or the values set in `nullValue`, put the - // index as the suffix. - s"_c$index" - } else if (!caseSensitive && duplicates.contains(value.toLowerCase)) { - // When there are case-insensitive duplicates, put the index as the suffix. - s"$value$index" - } else if (duplicates.contains(value)) { - // When there are duplicates, put the index as the suffix. - s"$value$index" - } else { - value - } - } - } else { - row.zipWithIndex.map { case (_, index) => - // Uses default column names, "_c#" where # is its position of fields - // when header option is disabled. - s"_c$index" - } - } + Some(CSVInferSchema.infer(lines, caseSensitive, csvOptions)) } override def prepareWrite( @@ -125,7 +67,7 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister { job: Job, options: Map[String, String], dataSchema: StructType): OutputWriterFactory = { - verifySchema(dataSchema) + CSVUtils.verifySchema(dataSchema) val conf = job.getConfiguration val csvOptions = new CSVOptions(options) csvOptions.compressionCodec.foreach { codec => @@ -155,13 +97,12 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister { options: Map[String, String], hadoopConf: Configuration): (PartitionedFile) => Iterator[InternalRow] = { val csvOptions = new CSVOptions(options) - val commentPrefix = csvOptions.comment.toString val broadcastedHadoopConf = sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf)) (file: PartitionedFile) => { - val lineIterator = { + val lines = { val conf = broadcastedHadoopConf.value.value val linesReader = new HadoopFileLinesReader(file, conf) Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => linesReader.close())) @@ -170,32 +111,21 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister { } } - // Consumes the header in the iterator. - CSVRelation.dropHeaderLine(file, lineIterator, csvOptions) - - val filteredIter = lineIterator.filter { line => - line.trim.nonEmpty && !line.startsWith(commentPrefix) + val linesWithoutHeader = if (csvOptions.headerFlag && file.start == 0) { + // Note that if there are only comments in the first block, the header would probably + // be not dropped. 
+ CSVUtils.dropHeaderLine(lines, csvOptions) + } else { + lines } + val filteredLines = CSVUtils.filterCommentAndEmpty(linesWithoutHeader, csvOptions) val parser = new UnivocityParser(dataSchema, requiredSchema, csvOptions) - filteredIter.flatMap(parser.parse) - } - } - - /** - * Returns the first line of the first non-empty file in path - */ - private def findFirstLine(options: CSVOptions, lines: Dataset[String]): String = { - import lines.sqlContext.implicits._ - val nonEmptyLines = lines.filter(length(trim($"value")) > 0) - if (options.isCommentSet) { - nonEmptyLines.filter(!$"value".startsWith(options.comment.toString)).first() - } else { - nonEmptyLines.first() + filteredLines.flatMap(parser.parse) } } - private def readText( + private def createBaseDataset( sparkSession: SparkSession, options: CSVOptions, inputPaths: Seq[String]): Dataset[String] = { @@ -215,22 +145,6 @@ class CSVFileFormat extends TextBasedFileFormat with DataSourceRegister { sparkSession.createDataset(rdd)(Encoders.STRING) } } - - private def verifySchema(schema: StructType): Unit = { - def verifyType(dataType: DataType): Unit = dataType match { - case ByteType | ShortType | IntegerType | LongType | FloatType | - DoubleType | BooleanType | _: DecimalType | TimestampType | - DateType | StringType => - - case udt: UserDefinedType[_] => verifyType(udt.sqlType) - - case _ => - throw new UnsupportedOperationException( - s"CSV data source does not support ${dataType.simpleString} data type.") - } - - schema.foreach(field => verifyType(field.dataType)) - } } private[csv] class CsvOutputWriter( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala index 065bf5357436685463b1e4625215fcb8e68e1821..485b186c7cf0c24e9346da8b0858fdd17fdd910c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala @@ -18,17 +18,15 @@ package org.apache.spark.sql.execution.datasources.csv import java.math.BigDecimal -import java.text.NumberFormat -import java.util.Locale import scala.util.control.Exception._ -import scala.util.Try -import org.apache.spark.rdd.RDD +import com.univocity.parsers.csv.CsvParser + import org.apache.spark.sql.catalyst.analysis.TypeCoercion import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.Dataset import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.UTF8String private[csv] object CSVInferSchema { @@ -39,22 +37,76 @@ private[csv] object CSVInferSchema { * 3. 
Replace any null types with string type */ def infer( - tokenRdd: RDD[Array[String]], - header: Array[String], + csv: Dataset[String], + caseSensitive: Boolean, options: CSVOptions): StructType = { - val startType: Array[DataType] = Array.fill[DataType](header.length)(NullType) - val rootTypes: Array[DataType] = - tokenRdd.aggregate(startType)(inferRowType(options), mergeRowTypes) - - val structFields = header.zip(rootTypes).map { case (thisHeader, rootType) => - val dType = rootType match { - case _: NullType => StringType - case other => other + val firstLine: String = CSVUtils.filterCommentAndEmpty(csv, options).first() + val firstRow = new CsvParser(options.asParserSettings).parseLine(firstLine) + val header = makeSafeHeader(firstRow, caseSensitive, options) + + val fields = if (options.inferSchemaFlag) { + val tokenRdd = csv.rdd.mapPartitions { iter => + val filteredLines = CSVUtils.filterCommentAndEmpty(iter, options) + val linesWithoutHeader = CSVUtils.filterHeaderLine(filteredLines, firstLine, options) + val parser = new CsvParser(options.asParserSettings) + linesWithoutHeader.map(parser.parseLine) + } + + val startType: Array[DataType] = Array.fill[DataType](header.length)(NullType) + val rootTypes: Array[DataType] = + tokenRdd.aggregate(startType)(inferRowType(options), mergeRowTypes) + + header.zip(rootTypes).map { case (thisHeader, rootType) => + val dType = rootType match { + case _: NullType => StringType + case other => other + } + StructField(thisHeader, dType, nullable = true) } - StructField(thisHeader, dType, nullable = true) + } else { + // By default fields are assumed to be StringType + header.map(fieldName => StructField(fieldName, StringType, nullable = true)) } - StructType(structFields) + StructType(fields) + } + + /** + * Generates a header from the given row which is null-safe and duplicate-safe. + */ + private def makeSafeHeader( + row: Array[String], + caseSensitive: Boolean, + options: CSVOptions): Array[String] = { + if (options.headerFlag) { + val duplicates = { + val headerNames = row.filter(_ != null) + .map(name => if (caseSensitive) name else name.toLowerCase) + headerNames.diff(headerNames.distinct).distinct + } + + row.zipWithIndex.map { case (value, index) => + if (value == null || value.isEmpty || value == options.nullValue) { + // When there are empty strings or the values set in `nullValue`, put the + // index as the suffix. + s"_c$index" + } else if (!caseSensitive && duplicates.contains(value.toLowerCase)) { + // When there are case-insensitive duplicates, put the index as the suffix. + s"$value$index" + } else if (duplicates.contains(value)) { + // When there are duplicates, put the index as the suffix. + s"$value$index" + } else { + value + } + } + } else { + row.zipWithIndex.map { case (_, index) => + // Uses default column names, "_c#" where # is its position of fields + // when header option is disabled. + s"_c$index" + } + } } private def inferRowType(options: CSVOptions) @@ -215,32 +267,3 @@ private[csv] object CSVInferSchema { case _ => None } } - -private[csv] object CSVTypeCast { - /** - * Helper method that converts string representation of a character to actual character. - * It handles some Java escaped strings and throws exception if given string is longer than one - * character. 
- */ - @throws[IllegalArgumentException] - def toChar(str: String): Char = { - if (str.charAt(0) == '\\') { - str.charAt(1) - match { - case 't' => '\t' - case 'r' => '\r' - case 'b' => '\b' - case 'f' => '\f' - case '\"' => '\"' // In case user changes quote char and uses \" as delimiter in options - case '\'' => '\'' - case 'u' if str == """\u0000""" => '\u0000' - case _ => - throw new IllegalArgumentException(s"Unsupported special character for delimiter: $str") - } - } else if (str.length == 1) { - str.charAt(0) - } else { - throw new IllegalArgumentException(s"Delimiter cannot be more than one character: $str") - } - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala index 140ce23958dc0dd30374499fa51b18725c8add57..af456c8d7114fd529f22fa2a199eb1987032ba65 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala @@ -69,7 +69,7 @@ private[csv] class CSVOptions(@transient private val parameters: CaseInsensitive } } - val delimiter = CSVTypeCast.toChar( + val delimiter = CSVUtils.toChar( parameters.getOrElse("sep", parameters.getOrElse("delimiter", ","))) private val parseMode = parameters.getOrElse("mode", "PERMISSIVE") val charset = parameters.getOrElse("encoding", diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala deleted file mode 100644 index 19058c23abe750c1ed939265e711e545ff72006e..0000000000000000000000000000000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala +++ /dev/null @@ -1,69 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.datasources.csv - -import com.univocity.parsers.csv.CsvParser - -import org.apache.spark.internal.Logging -import org.apache.spark.rdd.RDD -import org.apache.spark.sql._ -import org.apache.spark.sql.execution.datasources.PartitionedFile - -object CSVRelation extends Logging { - - def univocityTokenizer( - file: Dataset[String], - firstLine: String, - params: CSVOptions): RDD[Array[String]] = { - // If header is set, make sure firstLine is materialized before sending to executors. 
- val commentPrefix = params.comment.toString - file.rdd.mapPartitions { iter => - val parser = new CsvParser(params.asParserSettings) - val filteredIter = iter.filter { line => - line.trim.nonEmpty && !line.startsWith(commentPrefix) - } - if (params.headerFlag) { - filteredIter.filterNot(_ == firstLine).map { item => - parser.parseLine(item) - } - } else { - filteredIter.map { item => - parser.parseLine(item) - } - } - } - } - - // Skips the header line of each file if the `header` option is set to true. - def dropHeaderLine( - file: PartitionedFile, lines: Iterator[String], csvOptions: CSVOptions): Unit = { - // TODO What if the first partitioned file consists of only comments and empty lines? - if (csvOptions.headerFlag && file.start == 0) { - val nonEmptyLines = if (csvOptions.isCommentSet) { - val commentPrefix = csvOptions.comment.toString - lines.dropWhile { line => - line.trim.isEmpty || line.trim.startsWith(commentPrefix) - } - } else { - lines.dropWhile(_.trim.isEmpty) - } - - if (nonEmptyLines.hasNext) nonEmptyLines.drop(1) - } - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtils.scala new file mode 100644 index 0000000000000000000000000000000000000000..72b053d2092ca7b0e1ef042bcf56e6d5cb4e1ba4 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtils.scala @@ -0,0 +1,134 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.csv + +import org.apache.spark.sql.Dataset +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types._ + +object CSVUtils { + /** + * Filter ignorable rows for CSV dataset (lines empty and starting with `comment`). + * This is currently being used in CSV schema inference. + */ + def filterCommentAndEmpty(lines: Dataset[String], options: CSVOptions): Dataset[String] = { + // Note that this was separately made by SPARK-18362. Logically, this should be the same + // with the one below, `filterCommentAndEmpty` but execution path is different. One of them + // might have to be removed in the near future if possible. + import lines.sqlContext.implicits._ + val nonEmptyLines = lines.filter(length(trim($"value")) > 0) + if (options.isCommentSet) { + nonEmptyLines.filter(!$"value".startsWith(options.comment.toString)) + } else { + nonEmptyLines + } + } + + /** + * Filter ignorable rows for CSV iterator (lines empty and starting with `comment`). + * This is currently being used in CSV reading path and CSV schema inference. 
+ */ + def filterCommentAndEmpty(iter: Iterator[String], options: CSVOptions): Iterator[String] = { + iter.filter { line => + line.trim.nonEmpty && !line.startsWith(options.comment.toString) + } + } + + /** + * Skip the given first line so that only data can remain in a dataset. + * This is similar with `dropHeaderLine` below and currently being used in CSV schema inference. + */ + def filterHeaderLine( + iter: Iterator[String], + firstLine: String, + options: CSVOptions): Iterator[String] = { + // Note that unlike actual CSV reading path, it simply filters the given first line. Therefore, + // this skips the line same with the header if exists. One of them might have to be removed + // in the near future if possible. + if (options.headerFlag) { + iter.filterNot(_ == firstLine) + } else { + iter + } + } + + /** + * Drop header line so that only data can remain. + * This is similar with `filterHeaderLine` above and currently being used in CSV reading path. + */ + def dropHeaderLine(iter: Iterator[String], options: CSVOptions): Iterator[String] = { + val nonEmptyLines = if (options.isCommentSet) { + val commentPrefix = options.comment.toString + iter.dropWhile { line => + line.trim.isEmpty || line.trim.startsWith(commentPrefix) + } + } else { + iter.dropWhile(_.trim.isEmpty) + } + + if (nonEmptyLines.hasNext) nonEmptyLines.drop(1) + iter + } + + /** + * Helper method that converts string representation of a character to actual character. + * It handles some Java escaped strings and throws exception if given string is longer than one + * character. + */ + @throws[IllegalArgumentException] + def toChar(str: String): Char = { + if (str.charAt(0) == '\\') { + str.charAt(1) + match { + case 't' => '\t' + case 'r' => '\r' + case 'b' => '\b' + case 'f' => '\f' + case '\"' => '\"' // In case user changes quote char and uses \" as delimiter in options + case '\'' => '\'' + case 'u' if str == """\u0000""" => '\u0000' + case _ => + throw new IllegalArgumentException(s"Unsupported special character for delimiter: $str") + } + } else if (str.length == 1) { + str.charAt(0) + } else { + throw new IllegalArgumentException(s"Delimiter cannot be more than one character: $str") + } + } + + /** + * Verify if the schema is supported in CSV datasource. 
+ */ + def verifySchema(schema: StructType): Unit = { + def verifyType(dataType: DataType): Unit = dataType match { + case ByteType | ShortType | IntegerType | LongType | FloatType | + DoubleType | BooleanType | _: DecimalType | TimestampType | + DateType | StringType => + + case udt: UserDefinedType[_] => verifyType(udt.sqlType) + + case _ => + throw new UnsupportedOperationException( + s"CSV data source does not support ${dataType.simpleString} data type.") + } + + schema.foreach(field => verifyType(field.dataType)) + } + +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala index be1f94dbad9124e961db9c690f7e06e31aaafe30..98ab9d28500032edd26628a177782dacea2021ee 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala @@ -51,7 +51,7 @@ class JsonFileFormat extends TextBasedFileFormat with DataSourceRegister { val columnNameOfCorruptRecord = parsedOptions.columnNameOfCorruptRecord .getOrElse(sparkSession.sessionState.conf.columnNameOfCorruptRecord) - val jsonSchema = InferSchema.infer( + val jsonSchema = JsonInferSchema.infer( createBaseRdd(sparkSession, files), columnNameOfCorruptRecord, parsedOptions) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala similarity index 99% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala index 330d04de666c8b680656482edf236a991e96c51a..f51c18d46f45d9ae95b8f406abe3fe8b292b01b4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.json.JSONOptions import org.apache.spark.sql.types._ import org.apache.spark.util.Utils -private[sql] object InferSchema { +private[sql] object JsonInferSchema { /** * Infer the type of a collection of json records in three stages: diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtilsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtilsSuite.scala new file mode 100644 index 0000000000000000000000000000000000000000..221e44ce2cff6c6f0a0777d96ac419ff5ac73216 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVUtilsSuite.scala @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.csv + +import org.apache.spark.SparkFunSuite + +class CSVUtilsSuite extends SparkFunSuite { + test("Can parse escaped characters") { + assert(CSVUtils.toChar("""\t""") === '\t') + assert(CSVUtils.toChar("""\r""") === '\r') + assert(CSVUtils.toChar("""\b""") === '\b') + assert(CSVUtils.toChar("""\f""") === '\f') + assert(CSVUtils.toChar("""\"""") === '\"') + assert(CSVUtils.toChar("""\'""") === '\'') + assert(CSVUtils.toChar("""\u0000""") === '\u0000') + } + + test("Does not accept delimiter larger than one character") { + val exception = intercept[IllegalArgumentException]{ + CSVUtils.toChar("ab") + } + assert(exception.getMessage.contains("cannot be more than one character")) + } + + test("Throws exception for unsupported escaped characters") { + val exception = intercept[IllegalArgumentException]{ + CSVUtils.toChar("""\1""") + } + assert(exception.getMessage.contains("Unsupported special character for delimiter")) + } + +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParserSuite.scala index 2ca6308852a7e79e1010da1bc5ff87dde5c62b3f..62dae08861df1f8d5c14732a633bfcb5a28593d8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParserSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParserSuite.scala @@ -43,30 +43,6 @@ class UnivocityParserSuite extends SparkFunSuite { } } - test("Can parse escaped characters") { - assert(CSVTypeCast.toChar("""\t""") === '\t') - assert(CSVTypeCast.toChar("""\r""") === '\r') - assert(CSVTypeCast.toChar("""\b""") === '\b') - assert(CSVTypeCast.toChar("""\f""") === '\f') - assert(CSVTypeCast.toChar("""\"""") === '\"') - assert(CSVTypeCast.toChar("""\'""") === '\'') - assert(CSVTypeCast.toChar("""\u0000""") === '\u0000') - } - - test("Does not accept delimiter larger than one character") { - val exception = intercept[IllegalArgumentException]{ - CSVTypeCast.toChar("ab") - } - assert(exception.getMessage.contains("cannot be more than one character")) - } - - test("Throws exception for unsupported escaped characters") { - val exception = intercept[IllegalArgumentException]{ - CSVTypeCast.toChar("""\1""") - } - assert(exception.getMessage.contains("Unsupported special character for delimiter")) - } - test("Nullable types are handled") { val types = Seq(ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType, BooleanType, DecimalType.DoubleDecimal, TimestampType, DateType, StringType) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index 161a409d83ad477a5964b7d3c13283a166e7861c..156fd965b468389909782b13e80ebc11e4becb30 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ 
-32,7 +32,7 @@ import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.json.{JacksonParser, JSONOptions} import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.datasources.DataSource -import org.apache.spark.sql.execution.datasources.json.InferSchema.compatibleType +import org.apache.spark.sql.execution.datasources.json.JsonInferSchema.compatibleType import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ @@ -1366,7 +1366,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { test("SPARK-6245 JsonRDD.inferSchema on empty RDD") { // This is really a test that it doesn't throw an exception - val emptySchema = InferSchema.infer(empty, "", new JSONOptions(Map.empty[String, String])) + val emptySchema = JsonInferSchema.infer(empty, "", new JSONOptions(Map.empty[String, String])) assert(StructType(Seq()) === emptySchema) } @@ -1390,7 +1390,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { } test("SPARK-8093 Erase empty structs") { - val emptySchema = InferSchema.infer( + val emptySchema = JsonInferSchema.infer( emptyRecords, "", new JSONOptions(Map.empty[String, String])) assert(StructType(Seq()) === emptySchema) }
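
Illustrative usage, not part of the patch: a minimal sketch of a CSV read that exercises the refactored helpers, assuming a spark-shell style SparkSession named `spark` and a hypothetical input path.

// Minimal sketch; `spark` is an existing SparkSession and the path is hypothetical.
val df = spark.read
  .option("header", "true")       // header line of the first block is skipped via CSVUtils.dropHeaderLine
  .option("inferSchema", "true")  // column types are inferred by CSVInferSchema.infer over the Dataset[String]
  .option("comment", "#")         // comment and blank lines are removed by CSVUtils.filterCommentAndEmpty
  .option("sep", "\\t")           // the delimiter string is converted to a Char by CSVUtils.toChar
  .csv("/tmp/people.csv")
df.printSchema()

With "inferSchema" set to "false" (the default), the same path still runs header handling, and every column falls back to StringType as in CSVInferSchema.infer.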