Skip to content
Snippets Groups Projects
Commit dd3b5455 authored by Rahul Tanwani's avatar Rahul Tanwani Committed by Reynold Xin
Browse files

[SPARK-13309][SQL] Fix type inference issue with CSV data

Fix type inference issue for sparse CSV data - https://issues.apache.org/jira/browse/SPARK-13309

Author: Rahul Tanwani <rahul@Rahuls-MacBook-Pro.local>

Closes #11194 from tanwanirahul/master.
parent 6dfc4a76
No related branches found
No related tags found
No related merge requests found
...@@ -29,7 +29,6 @@ import org.apache.spark.rdd.RDD ...@@ -29,7 +29,6 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.analysis.HiveTypeCoercion import org.apache.spark.sql.catalyst.analysis.HiveTypeCoercion
import org.apache.spark.sql.types._ import org.apache.spark.sql.types._
private[csv] object CSVInferSchema { private[csv] object CSVInferSchema {
/** /**
...@@ -48,7 +47,11 @@ private[csv] object CSVInferSchema { ...@@ -48,7 +47,11 @@ private[csv] object CSVInferSchema {
tokenRdd.aggregate(startType)(inferRowType(nullValue), mergeRowTypes) tokenRdd.aggregate(startType)(inferRowType(nullValue), mergeRowTypes)
val structFields = header.zip(rootTypes).map { case (thisHeader, rootType) => val structFields = header.zip(rootTypes).map { case (thisHeader, rootType) =>
StructField(thisHeader, rootType, nullable = true) val dType = rootType match {
case _: NullType => StringType
case other => other
}
StructField(thisHeader, dType, nullable = true)
} }
StructType(structFields) StructType(structFields)
...@@ -65,12 +68,8 @@ private[csv] object CSVInferSchema { ...@@ -65,12 +68,8 @@ private[csv] object CSVInferSchema {
} }
def mergeRowTypes(first: Array[DataType], second: Array[DataType]): Array[DataType] = { def mergeRowTypes(first: Array[DataType], second: Array[DataType]): Array[DataType] = {
first.zipAll(second, NullType, NullType).map { case ((a, b)) => first.zipAll(second, NullType, NullType).map { case (a, b) =>
val tpe = findTightestCommonType(a, b).getOrElse(StringType) findTightestCommonType(a, b).getOrElse(NullType)
tpe match {
case _: NullType => StringType
case other => other
}
} }
} }
...@@ -140,6 +139,8 @@ private[csv] object CSVInferSchema { ...@@ -140,6 +139,8 @@ private[csv] object CSVInferSchema {
case (t1, t2) if t1 == t2 => Some(t1) case (t1, t2) if t1 == t2 => Some(t1)
case (NullType, t1) => Some(t1) case (NullType, t1) => Some(t1)
case (t1, NullType) => Some(t1) case (t1, NullType) => Some(t1)
case (StringType, t2) => Some(StringType)
case (t1, StringType) => Some(StringType)
// Promote numeric types to the highest of the two and all numeric types to unlimited decimal // Promote numeric types to the highest of the two and all numeric types to unlimited decimal
case (t1, t2) if Seq(t1, t2).forall(numericPrecedence.contains) => case (t1, t2) if Seq(t1, t2).forall(numericPrecedence.contains) =>
...@@ -150,7 +151,6 @@ private[csv] object CSVInferSchema { ...@@ -150,7 +151,6 @@ private[csv] object CSVInferSchema {
} }
} }
private[csv] object CSVTypeCast { private[csv] object CSVTypeCast {
/** /**
......
A,B,C,D
1,,,
,1,,
,,1,
,,,1
...@@ -68,4 +68,9 @@ class InferSchemaSuite extends SparkFunSuite { ...@@ -68,4 +68,9 @@ class InferSchemaSuite extends SparkFunSuite {
assert(CSVInferSchema.inferField(DoubleType, "\\N", "\\N") == DoubleType) assert(CSVInferSchema.inferField(DoubleType, "\\N", "\\N") == DoubleType)
assert(CSVInferSchema.inferField(TimestampType, "\\N", "\\N") == TimestampType) assert(CSVInferSchema.inferField(TimestampType, "\\N", "\\N") == TimestampType)
} }
test("Merging NullTypes should yield NullType.") {
val mergedNullTypes = CSVInferSchema.mergeRowTypes(Array(NullType), Array(NullType))
assert(mergedNullTypes.deep == Array(NullType).deep)
}
} }
...@@ -37,6 +37,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { ...@@ -37,6 +37,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
private val emptyFile = "empty.csv" private val emptyFile = "empty.csv"
private val commentsFile = "comments.csv" private val commentsFile = "comments.csv"
private val disableCommentsFile = "disable_comments.csv" private val disableCommentsFile = "disable_comments.csv"
private val simpleSparseFile = "simple_sparse.csv"
private def testFile(fileName: String): String = { private def testFile(fileName: String): String = {
Thread.currentThread().getContextClassLoader.getResource(fileName).toString Thread.currentThread().getContextClassLoader.getResource(fileName).toString
...@@ -233,7 +234,6 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { ...@@ -233,7 +234,6 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
assert(result.schema.fieldNames.size === 1) assert(result.schema.fieldNames.size === 1)
} }
test("DDL test with empty file") { test("DDL test with empty file") {
sqlContext.sql(s""" sqlContext.sql(s"""
|CREATE TEMPORARY TABLE carsTable |CREATE TEMPORARY TABLE carsTable
...@@ -396,4 +396,16 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { ...@@ -396,4 +396,16 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
verifyCars(carsCopy, withHeader = true) verifyCars(carsCopy, withHeader = true)
} }
} }
test("Schema inference correctly identifies the datatype when data is sparse.") {
val df = sqlContext.read
.format("csv")
.option("header", "true")
.option("inferSchema", "true")
.load(testFile(simpleSparseFile))
assert(
df.schema.fields.map(field => field.dataType).deep ==
Array(IntegerType, IntegerType, IntegerType, IntegerType).deep)
}
} }
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment