diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index ee48baa59c7af63450850e0a9b707b7404f0fd09..c669c2e2e26efbfa85301aee877b09f65c379cf6 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -2684,7 +2684,7 @@ test_that("Call DataFrameWriter.load() API in Java without path and check argume # It makes sure that we can omit path argument in read.df API and then it calls # DataFrameWriter.load() without path. expect_error(read.df(source = "json"), - paste("Error in loadDF : analysis error - Unable to infer schema for JSON at .", + paste("Error in loadDF : analysis error - Unable to infer schema for JSON.", "It must be specified manually")) expect_error(read.df("arbitrary_path"), "Error in loadDF : analysis error - Path does not exist") expect_error(read.json("arbitrary_path"), "Error in json : analysis error - Path does not exist") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index 84fde0bbf9268f26831cc0626d97e9137eaee604..dbc3e712332f79e6149b4be9c7c17079755d7a10 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -61,8 +61,12 @@ import org.apache.spark.util.Utils * qualified. This option only works when reading from a [[FileFormat]]. * @param userSpecifiedSchema An optional specification of the schema of the data. When present * we skip attempting to infer the schema. - * @param partitionColumns A list of column names that the relation is partitioned by. When this - * list is empty, the relation is unpartitioned. + * @param partitionColumns A list of column names that the relation is partitioned by. This list is + * generally empty during the read path, unless this DataSource is managed + * by Hive. In these cases, during `resolveRelation`, we will call + * `getOrInferFileFormatSchema` for file based DataSources to infer the + * partitioning. In other cases, if this list is empty, then this table + * is unpartitioned. * @param bucketSpec An optional specification for bucketing (hash-partitioning) of the data. * @param catalogTable Optional catalog table reference that can be used to push down operations * over the datasource to the catalog service. @@ -84,30 +88,106 @@ case class DataSource( private val caseInsensitiveOptions = new CaseInsensitiveMap(options) /** - * Infer the schema of the given FileFormat, returns a pair of schema and partition column names. + * Get the schema of the given FileFormat, if provided by `userSpecifiedSchema`, or try to infer + * it. In the read path, only managed tables by Hive provide the partition columns properly when + * initializing this class. All other file based data sources will try to infer the partitioning, + * and then cast the inferred types to user specified dataTypes if the partition columns exist + * inside `userSpecifiedSchema`, otherwise we can hit data corruption bugs like SPARK-18510. + * This method will try to skip file scanning whether `userSpecifiedSchema` and + * `partitionColumns` are provided. Here are some code paths that use this method: + * 1. `spark.read` (no schema): Most amount of work. Infer both schema and partitioning columns + * 2. `spark.read.schema(userSpecifiedSchema)`: Parse partitioning columns, cast them to the + * dataTypes provided in `userSpecifiedSchema` if they exist or fallback to inferred + * dataType if they don't. + * 3. `spark.readStream.schema(userSpecifiedSchema)`: For streaming use cases, users have to + * provide the schema. Here, we also perform partition inference like 2, and try to use + * dataTypes in `userSpecifiedSchema`. All subsequent triggers for this stream will re-use + * this information, therefore calls to this method should be very cheap, i.e. there won't + * be any further inference in any triggers. + * 4. `df.saveAsTable(tableThatExisted)`: In this case, we call this method to resolve the + * existing table's partitioning scheme. This is achieved by not providing + * `userSpecifiedSchema`. For this case, we add the boolean `justPartitioning` for an early + * exit, if we don't care about the schema of the original table. + * + * @param format the file format object for this DataSource + * @param justPartitioning Whether to exit early and provide just the schema partitioning. + * @return A pair of the data schema (excluding partition columns) and the schema of the partition + * columns. If `justPartitioning` is `true`, then the dataSchema will be provided as + * `null`. */ - private def inferFileFormatSchema(format: FileFormat): (StructType, Seq[String]) = { - userSpecifiedSchema.map(_ -> partitionColumns).orElse { - val allPaths = caseInsensitiveOptions.get("path") + private def getOrInferFileFormatSchema( + format: FileFormat, + justPartitioning: Boolean = false): (StructType, StructType) = { + // the operations below are expensive therefore try not to do them if we don't need to + lazy val tempFileCatalog = { + val allPaths = caseInsensitiveOptions.get("path") ++ paths + val hadoopConf = sparkSession.sessionState.newHadoopConf() val globbedPaths = allPaths.toSeq.flatMap { path => val hdfsPath = new Path(path) - val fs = hdfsPath.getFileSystem(sparkSession.sessionState.newHadoopConf()) + val fs = hdfsPath.getFileSystem(hadoopConf) val qualified = hdfsPath.makeQualified(fs.getUri, fs.getWorkingDirectory) SparkHadoopUtil.get.globPathIfNecessary(qualified) }.toArray - val fileCatalog = new InMemoryFileIndex(sparkSession, globbedPaths, options, None) - val partitionSchema = fileCatalog.partitionSpec().partitionColumns - val inferred = format.inferSchema( + new InMemoryFileIndex(sparkSession, globbedPaths, options, None) + } + val partitionSchema = if (partitionColumns.isEmpty && catalogTable.isEmpty) { + // Try to infer partitioning, because no DataSource in the read path provides the partitioning + // columns properly unless it is a Hive DataSource + val resolved = tempFileCatalog.partitionSchema.map { partitionField => + val equality = sparkSession.sessionState.conf.resolver + // SPARK-18510: try to get schema from userSpecifiedSchema, otherwise fallback to inferred + userSpecifiedSchema.flatMap(_.find(f => equality(f.name, partitionField.name))).getOrElse( + partitionField) + } + StructType(resolved) + } else { + // in streaming mode, we have already inferred and registered partition columns, we will + // never have to materialize the lazy val below + lazy val inferredPartitions = tempFileCatalog.partitionSchema + // maintain old behavior before SPARK-18510. If userSpecifiedSchema is empty used inferred + // partitioning + if (userSpecifiedSchema.isEmpty) { + inferredPartitions + } else { + val partitionFields = partitionColumns.map { partitionColumn => + userSpecifiedSchema.flatMap(_.find(_.name == partitionColumn)).orElse { + val inferredOpt = inferredPartitions.find(_.name == partitionColumn) + if (inferredOpt.isDefined) { + logDebug( + s"""Type of partition column: $partitionColumn not found in specified schema + |for $format. + |User Specified Schema + |===================== + |${userSpecifiedSchema.orNull} + | + |Falling back to inferred dataType if it exists. + """.stripMargin) + } + inferredPartitions.find(_.name == partitionColumn) + }.getOrElse { + throw new AnalysisException(s"Failed to resolve the schema for $format for " + + s"the partition column: $partitionColumn. It must be specified manually.") + } + } + StructType(partitionFields) + } + } + if (justPartitioning) { + return (null, partitionSchema) + } + val dataSchema = userSpecifiedSchema.map { schema => + val equality = sparkSession.sessionState.conf.resolver + StructType(schema.filterNot(f => partitionSchema.exists(p => equality(p.name, f.name)))) + }.orElse { + format.inferSchema( sparkSession, caseInsensitiveOptions, - fileCatalog.allFiles()) - - inferred.map { inferredSchema => - StructType(inferredSchema ++ partitionSchema) -> partitionSchema.map(_.name) - } + tempFileCatalog.allFiles()) }.getOrElse { - throw new AnalysisException("Unable to infer schema. It must be specified manually.") + throw new AnalysisException( + s"Unable to infer schema for $format. It must be specified manually.") } + (dataSchema, partitionSchema) } /** Returns the name and schema of the source that can be used to continually read data. */ @@ -144,8 +224,8 @@ case class DataSource( "you may be able to create a static DataFrame on that directory with " + "'spark.read.load(directory)' and infer schema from it.") } - val (schema, partCols) = inferFileFormatSchema(format) - SourceInfo(s"FileSource[$path]", schema, partCols) + val (schema, partCols) = getOrInferFileFormatSchema(format) + SourceInfo(s"FileSource[$path]", StructType(schema ++ partCols), partCols.fieldNames) case _ => throw new UnsupportedOperationException( @@ -272,7 +352,7 @@ case class DataSource( HadoopFsRelation( fileCatalog, - partitionSchema = fileCatalog.partitionSpec().partitionColumns, + partitionSchema = fileCatalog.partitionSchema, dataSchema = dataSchema, bucketSpec = None, format, @@ -281,9 +361,10 @@ case class DataSource( // This is a non-streaming file based datasource. case (format: FileFormat, _) => val allPaths = caseInsensitiveOptions.get("path") ++ paths + val hadoopConf = sparkSession.sessionState.newHadoopConf() val globbedPaths = allPaths.flatMap { path => val hdfsPath = new Path(path) - val fs = hdfsPath.getFileSystem(sparkSession.sessionState.newHadoopConf()) + val fs = hdfsPath.getFileSystem(hadoopConf) val qualified = hdfsPath.makeQualified(fs.getUri, fs.getWorkingDirectory) val globPath = SparkHadoopUtil.get.globPathIfNecessary(qualified) @@ -291,23 +372,14 @@ case class DataSource( throw new AnalysisException(s"Path does not exist: $qualified") } // Sufficient to check head of the globPath seq for non-glob scenario + // Don't need to check once again if files exist in streaming mode if (checkFilesExist && !fs.exists(globPath.head)) { throw new AnalysisException(s"Path does not exist: ${globPath.head}") } globPath }.toArray - // If they gave a schema, then we try and figure out the types of the partition columns - // from that schema. - val partitionSchema = userSpecifiedSchema.map { schema => - StructType( - partitionColumns.map { c => - // TODO: Case sensitivity. - schema - .find(_.name.toLowerCase() == c.toLowerCase()) - .getOrElse(throw new AnalysisException(s"Invalid partition column '$c'")) - }) - } + val (dataSchema, inferredPartitionSchema) = getOrInferFileFormatSchema(format) val fileCatalog = if (sparkSession.sqlContext.conf.manageFilesourcePartitions && catalogTable.isDefined && catalogTable.get.tracksPartitionsInCatalog) { @@ -316,27 +388,12 @@ case class DataSource( catalogTable.get, catalogTable.get.stats.map(_.sizeInBytes.toLong).getOrElse(0L)) } else { - new InMemoryFileIndex( - sparkSession, globbedPaths, options, partitionSchema) - } - - val dataSchema = userSpecifiedSchema.map { schema => - val equality = sparkSession.sessionState.conf.resolver - StructType(schema.filterNot(f => partitionColumns.exists(equality(_, f.name)))) - }.orElse { - format.inferSchema( - sparkSession, - caseInsensitiveOptions, - fileCatalog.asInstanceOf[InMemoryFileIndex].allFiles()) - }.getOrElse { - throw new AnalysisException( - s"Unable to infer schema for $format at ${allPaths.take(2).mkString(",")}. " + - "It must be specified manually") + new InMemoryFileIndex(sparkSession, globbedPaths, options, Some(inferredPartitionSchema)) } HadoopFsRelation( fileCatalog, - partitionSchema = fileCatalog.partitionSchema, + partitionSchema = inferredPartitionSchema, dataSchema = dataSchema.asNullable, bucketSpec = bucketSpec, format, @@ -384,11 +441,7 @@ case class DataSource( // up. If we fail to load the table for whatever reason, ignore the check. if (mode == SaveMode.Append) { val existingPartitionColumns = Try { - resolveRelation() - .asInstanceOf[HadoopFsRelation] - .partitionSchema - .fieldNames - .toSeq + getOrInferFileFormatSchema(format, justPartitioning = true)._2.fieldNames.toList }.getOrElse(Seq.empty[String]) // TODO: Case sensitivity. val sameColumns = diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index 02d9d1568490415294b37b8e1942fc4e4cffb7ad..10843e9ba57535e180dc85bdbafb2bca128a816e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -274,7 +274,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { pathToPartitionedTable, userSpecifiedSchema = Option("num int, str string"), userSpecifiedPartitionCols = partitionCols, - expectedSchema = new StructType().add("num", IntegerType).add("str", StringType), + expectedSchema = new StructType().add("str", StringType).add("num", IntegerType), expectedPartitionCols = partitionCols.map(Seq(_)).getOrElse(Seq.empty[String])) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala index a099153d2e58ed98e9c12284b9e7fcb5bbec1737..bad6642ea40580480c0be5a6e3b6eaba9b27471e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala @@ -282,7 +282,7 @@ class FileStreamSourceSuite extends FileStreamSourceTest { createFileStreamSourceAndGetSchema( format = Some("json"), path = Some(src.getCanonicalPath), schema = None) } - assert("Unable to infer schema. It must be specified manually.;" === e.getMessage) + assert("Unable to infer schema for JSON. It must be specified manually.;" === e.getMessage) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala index 5630464f4080395f4c99ad03b69559e5d510c8e0..0eb95a02432fb6f58fad4b789f5183d19b246148 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql._ import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.sources.{StreamSinkProvider, StreamSourceProvider} import org.apache.spark.sql.streaming.{OutputMode, ProcessingTime, StreamingQuery, StreamTest} -import org.apache.spark.sql.types.{IntegerType, StructField, StructType} +import org.apache.spark.sql.types._ import org.apache.spark.util.Utils object LastOptions { @@ -532,4 +532,47 @@ class DataStreamReaderWriterSuite extends StreamTest with BeforeAndAfter { assert(e.getMessage.contains("does not support recovering")) assert(e.getMessage.contains("checkpoint location")) } + + test("SPARK-18510: use user specified types for partition columns in file sources") { + import org.apache.spark.sql.functions.udf + import testImplicits._ + withTempDir { src => + val createArray = udf { (length: Long) => + for (i <- 1 to length.toInt) yield i.toString + } + spark.range(4).select(createArray('id + 1) as 'ex, 'id, 'id % 4 as 'part).coalesce(1).write + .partitionBy("part", "id") + .mode("overwrite") + .parquet(src.toString) + // Specify a random ordering of the schema, partition column in the middle, etc. + // Also let's say that the partition columns are Strings instead of Longs. + // partition columns should go to the end + val schema = new StructType() + .add("id", StringType) + .add("ex", ArrayType(StringType)) + + val sdf = spark.readStream + .schema(schema) + .format("parquet") + .load(src.toString) + + assert(sdf.schema.toList === List( + StructField("ex", ArrayType(StringType)), + StructField("part", IntegerType), // inferred partitionColumn dataType + StructField("id", StringType))) // used user provided partitionColumn dataType + + val sq = sdf.writeStream + .queryName("corruption_test") + .format("memory") + .start() + sq.processAllAvailable() + checkAnswer( + spark.table("corruption_test"), + // notice how `part` is ordered before `id` + Row(Array("1"), 0, "0") :: Row(Array("1", "2"), 1, "1") :: + Row(Array("1", "2", "3"), 2, "2") :: Row(Array("1", "2", "3", "4"), 3, "3") :: Nil + ) + sq.stop() + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala index a7fda0109856021d7390a6ef270d4fc6a9d93e54..e0887e0f1c7dee1ef401817119685c18703924fa 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala @@ -24,7 +24,7 @@ import org.scalatest.BeforeAndAfter import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.sources._ -import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType} +import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -573,4 +573,40 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext with Be } } } + + test("SPARK-18510: use user specified types for partition columns in file sources") { + import org.apache.spark.sql.functions.udf + import testImplicits._ + withTempDir { src => + val createArray = udf { (length: Long) => + for (i <- 1 to length.toInt) yield i.toString + } + spark.range(4).select(createArray('id + 1) as 'ex, 'id, 'id % 4 as 'part).coalesce(1).write + .partitionBy("part", "id") + .mode("overwrite") + .parquet(src.toString) + // Specify a random ordering of the schema, partition column in the middle, etc. + // Also let's say that the partition columns are Strings instead of Longs. + // partition columns should go to the end + val schema = new StructType() + .add("id", StringType) + .add("ex", ArrayType(StringType)) + val df = spark.read + .schema(schema) + .format("parquet") + .load(src.toString) + + assert(df.schema.toList === List( + StructField("ex", ArrayType(StringType)), + StructField("part", IntegerType), // inferred partitionColumn dataType + StructField("id", StringType))) // used user provided partitionColumn dataType + + checkAnswer( + df, + // notice how `part` is ordered before `id` + Row(Array("1"), 0, "0") :: Row(Array("1", "2"), 1, "1") :: + Row(Array("1", "2", "3"), 2, "2") :: Row(Array("1", "2", "3", "4"), 3, "3") :: Nil + ) + } + } }