Skip to content
Snippets Groups Projects
Commit dd3b5455 authored by Rahul Tanwani's avatar Rahul Tanwani Committed by Reynold Xin
Browse files

[SPARK-13309][SQL] Fix type inference issue with CSV data

Fix type inference issue for sparse CSV data - https://issues.apache.org/jira/browse/SPARK-13309

Author: Rahul Tanwani <rahul@Rahuls-MacBook-Pro.local>

Closes #11194 from tanwanirahul/master.
parent 6dfc4a76
No related branches found
No related tags found
No related merge requests found
......@@ -29,7 +29,6 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.analysis.HiveTypeCoercion
import org.apache.spark.sql.types._
private[csv] object CSVInferSchema {
/**
......@@ -48,7 +47,11 @@ private[csv] object CSVInferSchema {
tokenRdd.aggregate(startType)(inferRowType(nullValue), mergeRowTypes)
val structFields = header.zip(rootTypes).map { case (thisHeader, rootType) =>
StructField(thisHeader, rootType, nullable = true)
val dType = rootType match {
case _: NullType => StringType
case other => other
}
StructField(thisHeader, dType, nullable = true)
}
StructType(structFields)
......@@ -65,12 +68,8 @@ private[csv] object CSVInferSchema {
}
/**
 * Merges two per-row arrays of inferred column types into one array, column by column.
 *
 * Rows of unequal length are padded with `NullType` via `zipAll`, and each column pair is
 * widened to the tightest common type. When no common type exists the result is `NullType`
 * rather than `StringType`: for sparse data (SPARK-13309) a column may be all-null in early
 * rows, and falling back to `StringType` here would prematurely lock the column to strings
 * before a later row could reveal its real type. Columns that remain `NullType` after all
 * rows are merged are resolved to `StringType` when the final schema is built.
 *
 * @param first  inferred types for the columns of one row
 * @param second inferred types for the columns of another row
 * @return the merged per-column types
 */
def mergeRowTypes(first: Array[DataType], second: Array[DataType]): Array[DataType] = {
  first.zipAll(second, NullType, NullType).map { case (a, b) =>
    findTightestCommonType(a, b).getOrElse(NullType)
  }
}
......@@ -140,6 +139,8 @@ private[csv] object CSVInferSchema {
case (t1, t2) if t1 == t2 => Some(t1)
case (NullType, t1) => Some(t1)
case (t1, NullType) => Some(t1)
case (StringType, t2) => Some(StringType)
case (t1, StringType) => Some(StringType)
// Promote numeric types to the highest of the two and all numeric types to unlimited decimal
case (t1, t2) if Seq(t1, t2).forall(numericPrecedence.contains) =>
......@@ -150,7 +151,6 @@ private[csv] object CSVInferSchema {
}
}
private[csv] object CSVTypeCast {
/**
......
A,B,C,D
1,,,
,1,,
,,1,
,,,1
......@@ -68,4 +68,9 @@ class InferSchemaSuite extends SparkFunSuite {
assert(CSVInferSchema.inferField(DoubleType, "\\N", "\\N") == DoubleType)
assert(CSVInferSchema.inferField(TimestampType, "\\N", "\\N") == TimestampType)
}
// Regression test for SPARK-13309: merging two all-null rows must yield NullType
// (not StringType), so that a later row can still promote the column to a
// concrete type such as IntegerType.
// Note: test name typo fixed ("yeild" -> "yield", "Nulltype" -> "NullType").
test("Merging NullTypes should yield NullType.") {
  val mergedNullTypes = CSVInferSchema.mergeRowTypes(Array(NullType), Array(NullType))
  assert(mergedNullTypes.deep == Array(NullType).deep)
}
}
......@@ -37,6 +37,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
private val emptyFile = "empty.csv"
private val commentsFile = "comments.csv"
private val disableCommentsFile = "disable_comments.csv"
private val simpleSparseFile = "simple_sparse.csv"
private def testFile(fileName: String): String = {
Thread.currentThread().getContextClassLoader.getResource(fileName).toString
......@@ -233,7 +234,6 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
assert(result.schema.fieldNames.size === 1)
}
test("DDL test with empty file") {
sqlContext.sql(s"""
|CREATE TEMPORARY TABLE carsTable
......@@ -396,4 +396,16 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
verifyCars(carsCopy, withHeader = true)
}
}
// End-to-end check for SPARK-13309: in simple_sparse.csv every column holds a
// single integer amid nulls, so schema inference must type all four columns as
// IntegerType instead of falling back to StringType.
test("Schema inference correctly identifies the datatype when data is sparse.") {
  val sparseDf = sqlContext.read
    .format("csv")
    .option("header", "true")
    .option("inferSchema", "true")
    .load(testFile(simpleSparseFile))
  val inferredTypes = sparseDf.schema.fields.map(_.dataType)
  assert(inferredTypes.deep == Array.fill(4)(IntegerType).deep)
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment