diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala index 5094058164b2f687fee14db600ed41df7ad53437..5770f59b5307736e7a85a60e50f53660a00c1799 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala @@ -75,6 +75,10 @@ case class ArrayType(elementType: DataType, containsNull: Boolean) extends DataT override def simpleString: String = s"array<${elementType.simpleString}>" - private[spark] override def asNullable: ArrayType = + override private[spark] def asNullable: ArrayType = ArrayType(elementType.asNullable, containsNull = true) + + override private[spark] def existsRecursively(f: (DataType) => Boolean): Boolean = { + f(this) || elementType.existsRecursively(f) + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala index f4428c2e8b202fa5702052eac1c0b11231ea47a0..7bcd623b3f33e7daef7ed4e14b5f5fc6a559c34c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala @@ -77,6 +77,11 @@ abstract class DataType extends AbstractDataType { */ private[spark] def asNullable: DataType + /** + * Returns true if any `DataType` of this DataType tree satisfies the given function `f`. + */ + private[spark] def existsRecursively(f: (DataType) => Boolean): Boolean = f(this) + override private[sql] def defaultConcreteType: DataType = this override private[sql] def acceptsType(other: DataType): Boolean = sameType(other) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala index ac34b642827ca7696b6389f6a03ef78ecfc8634c..00461e529ca0a220e24f8030adcbea2917542d26 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala @@ -62,8 +62,12 @@ case class MapType( override def simpleString: String = s"map<${keyType.simpleString},${valueType.simpleString}>" - private[spark] override def asNullable: MapType = + override private[spark] def asNullable: MapType = MapType(keyType.asNullable, valueType.asNullable, valueContainsNull = true) + + override private[spark] def existsRecursively(f: (DataType) => Boolean): Boolean = { + f(this) || keyType.existsRecursively(f) || valueType.existsRecursively(f) + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala index 9cbc207538d4ff0385583fe7308110d437317c5e..d8968ef80639057ef4ef448503d64e0b236e6265 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala @@ -24,7 +24,7 @@ import org.json4s.JsonDSL._ import org.apache.spark.SparkException import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.sql.catalyst.expressions.{InterpretedOrdering, AttributeReference, Attribute, InterpretedOrdering$} +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, InterpretedOrdering} /** @@ -292,7 +292,7 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru private[sql] def merge(that: StructType): StructType = StructType.merge(this, that).asInstanceOf[StructType] - private[spark] override def asNullable: StructType = { + override private[spark] def asNullable: StructType = { val newFields = fields.map { case StructField(name, dataType, nullable, metadata) => StructField(name, dataType.asNullable, nullable = true, metadata) @@ -301,6 +301,10 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru StructType(newFields) } + override private[spark] def existsRecursively(f: (DataType) => Boolean): Boolean = { + f(this) || fields.exists(field => field.dataType.existsRecursively(f)) + } + private[sql] val interpretedOrdering = InterpretedOrdering.forSchema(this.fields.map(_.dataType)) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala index 88b221cd81d7415e923b1dbae89f86c2e9d47316..706ecd29d13554a96a49c980780ac37e3025fd4a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala @@ -170,6 +170,30 @@ class DataTypeSuite extends SparkFunSuite { } } + test("existsRecursively") { + val struct = StructType( + StructField("a", LongType) :: + StructField("b", FloatType) :: Nil) + assert(struct.existsRecursively(_.isInstanceOf[LongType])) + assert(struct.existsRecursively(_.isInstanceOf[StructType])) + assert(!struct.existsRecursively(_.isInstanceOf[IntegerType])) + + val mapType = MapType(struct, StringType) + assert(mapType.existsRecursively(_.isInstanceOf[LongType])) + assert(mapType.existsRecursively(_.isInstanceOf[StructType])) + assert(mapType.existsRecursively(_.isInstanceOf[StringType])) + assert(mapType.existsRecursively(_.isInstanceOf[MapType])) + assert(!mapType.existsRecursively(_.isInstanceOf[IntegerType])) + + val arrayType = ArrayType(mapType) + assert(arrayType.existsRecursively(_.isInstanceOf[LongType])) + assert(arrayType.existsRecursively(_.isInstanceOf[StructType])) + assert(arrayType.existsRecursively(_.isInstanceOf[StringType])) + assert(arrayType.existsRecursively(_.isInstanceOf[MapType])) + assert(arrayType.existsRecursively(_.isInstanceOf[ArrayType])) + assert(!arrayType.existsRecursively(_.isInstanceOf[IntegerType])) + } + def checkDataTypeJsonRepr(dataType: DataType): Unit = { test(s"JSON - $dataType") { assert(DataType.fromJson(dataType.json) === dataType) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 5e5497837a393142d4468a862f2746e473e2202f..6770462bb0ad3714dde2a3020e56ee96f09363a6 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -33,15 +33,14 @@ import org.apache.hadoop.hive.ql.plan.TableDesc import org.apache.spark.Logging import org.apache.spark.sql.catalyst.analysis.{Catalog, MultiInstanceRelation, OverrideCatalog} import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.catalyst.{InternalRow, SqlParser, TableIdentifier} -import org.apache.spark.sql.execution.{FileRelation, datasources} +import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation import org.apache.spark.sql.execution.datasources.{CreateTableUsingAsSelect, LogicalRelation, Partition => ParquetPartition, PartitionSpec, ResolvedDataSource} +import org.apache.spark.sql.execution.{FileRelation, datasources} import org.apache.spark.sql.hive.client._ -import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation import org.apache.spark.sql.sources._ import org.apache.spark.sql.types._ import org.apache.spark.sql.{AnalysisException, SQLContext, SaveMode} @@ -86,9 +85,9 @@ private[hive] object HiveSerDe { serde = Option("org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe"))) val key = source.toLowerCase match { - case _ if source.startsWith("org.apache.spark.sql.parquet") => "parquet" - case _ if source.startsWith("org.apache.spark.sql.orc") => "orc" - case _ => source.toLowerCase + case s if s.startsWith("org.apache.spark.sql.parquet") => "parquet" + case s if s.startsWith("org.apache.spark.sql.orc") => "orc" + case s => s } serdeMap.get(key) @@ -309,11 +308,31 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive val hiveTable = (maybeSerDe, dataSource.relation) match { case (Some(serde), relation: HadoopFsRelation) if relation.paths.length == 1 && relation.partitionColumns.isEmpty => - logInfo { - "Persisting data source relation with a single input path into Hive metastore in Hive " + - s"compatible format. Input path: ${relation.paths.head}" + // Hive ParquetSerDe doesn't support decimal type until 1.2.0. + val isParquetSerDe = serde.inputFormat.exists(_.toLowerCase.contains("parquet")) + val hasDecimalFields = relation.schema.existsRecursively(_.isInstanceOf[DecimalType]) + + val hiveParquetSupportsDecimal = client.version match { + case org.apache.spark.sql.hive.client.hive.v1_2 => true + case _ => false + } + + if (isParquetSerDe && !hiveParquetSupportsDecimal && hasDecimalFields) { + // If Hive version is below 1.2.0, we cannot save Hive compatible schema to + // metastore when the file format is Parquet and the schema has DecimalType. + logWarning { + "Persisting Parquet relation with decimal field(s) into Hive metastore in Spark SQL " + + "specific format, which is NOT compatible with Hive. Because ParquetHiveSerDe in " + + s"Hive ${client.version.fullVersion} doesn't support decimal type. See HIVE-6384." + } + newSparkSQLSpecificMetastoreTable() + } else { + logInfo { + "Persisting data source relation with a single input path into Hive metastore in " + + s"Hive compatible format. Input path: ${relation.paths.head}" + } + newHiveCompatibleMetastoreTable(relation, serde) } - newHiveCompatibleMetastoreTable(relation, serde) case (Some(serde), relation: HadoopFsRelation) if relation.partitionColumns.nonEmpty => logWarning { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientInterface.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientInterface.scala index a82e152dcda2cb796222752d2232b4735afbf128..3811c152a7ae6c15f96dbd0000265d8dcd8dffaa 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientInterface.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientInterface.scala @@ -88,6 +88,9 @@ private[hive] case class HiveTable( */ private[hive] trait ClientInterface { + /** Returns the Hive Version of this client. */ + def version: HiveVersion + /** Returns the configuration for the given key in the current session. */ def getConf(key: String, defaultValue: String): String diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala index 3d05b583cf9e0d1391bef5f60a1d811e3e8f225d..f49c97de8ff4e8286ba8bcf95b99600451b2b88a 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala @@ -58,7 +58,7 @@ import org.apache.spark.util.{CircularBuffer, Utils} * this ClientWrapper. */ private[hive] class ClientWrapper( - version: HiveVersion, + override val version: HiveVersion, config: Map[String, String], initClassLoader: ClassLoader) extends ClientInterface diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala index 0503691a44249f591b207fb57f9636b1dfcba7fa..b1b8439efa011e15369f02bc6e86901c444d2459 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala @@ -25,7 +25,7 @@ package object client { val exclusions: Seq[String] = Nil) // scalastyle:off - private[client] object hive { + private[hive] object hive { case object v12 extends HiveVersion("0.12.0") case object v13 extends HiveVersion("0.13.1") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala index 332c3ec0c28b873bb5e58818dec4a454fa85b1ac..59e65ff97b8e023153fc9927f60668f258ce933b 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala @@ -19,13 +19,13 @@ package org.apache.spark.sql.hive import java.io.File -import org.apache.spark.sql.hive.client.{ExternalTable, HiveColumn, ManagedTable} +import org.apache.spark.sql.hive.client.{ExternalTable, ManagedTable} import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ import org.apache.spark.sql.hive.test.TestHive.implicits._ import org.apache.spark.sql.sources.DataSourceTest import org.apache.spark.sql.test.{ExamplePointUDT, SQLTestUtils} -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types.{DecimalType, StringType, StructType} import org.apache.spark.sql.{Row, SaveMode} import org.apache.spark.{Logging, SparkFunSuite} @@ -55,7 +55,10 @@ class HiveMetastoreCatalogSuite extends SparkFunSuite with Logging { class DataSourceWithHiveMetastoreCatalogSuite extends DataSourceTest with SQLTestUtils { override val sqlContext = TestHive - private val testDF = (1 to 2).map(i => (i, s"val_$i")).toDF("d1", "d2").coalesce(1) + private val testDF = range(1, 3).select( + ('id + 0.1) cast DecimalType(10, 3) as 'd1, + 'id cast StringType as 'd2 + ).coalesce(1) Seq( "parquet" -> ( @@ -88,10 +91,10 @@ class DataSourceWithHiveMetastoreCatalogSuite extends DataSourceTest with SQLTes val columns = hiveTable.schema assert(columns.map(_.name) === Seq("d1", "d2")) - assert(columns.map(_.hiveType) === Seq("int", "string")) + assert(columns.map(_.hiveType) === Seq("decimal(10,3)", "string")) checkAnswer(table("t"), testDF) - assert(runSqlHive("SELECT * FROM t") === Seq("1\tval_1", "2\tval_2")) + assert(runSqlHive("SELECT * FROM t") === Seq("1.1\t1", "2.1\t2")) } } @@ -117,10 +120,10 @@ class DataSourceWithHiveMetastoreCatalogSuite extends DataSourceTest with SQLTes val columns = hiveTable.schema assert(columns.map(_.name) === Seq("d1", "d2")) - assert(columns.map(_.hiveType) === Seq("int", "string")) + assert(columns.map(_.hiveType) === Seq("decimal(10,3)", "string")) checkAnswer(table("t"), testDF) - assert(runSqlHive("SELECT * FROM t") === Seq("1\tval_1", "2\tval_2")) + assert(runSqlHive("SELECT * FROM t") === Seq("1.1\t1", "2.1\t2")) } } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala index 1e1972d1ac353e3d5c62f3423db884fc0b30d6ed..0c296461144653cb577bb1fa3ad8bc8cda4598c4 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala @@ -20,16 +20,18 @@ package org.apache.spark.sql.hive import java.io.File import scala.collection.mutable.ArrayBuffer -import scala.sys.process.{ProcessLogger, Process} +import scala.sys.process.{Process, ProcessLogger} +import org.scalatest.Matchers +import org.scalatest.concurrent.Timeouts import org.scalatest.exceptions.TestFailedDueToTimeoutException +import org.scalatest.time.SpanSugar._ import org.apache.spark._ +import org.apache.spark.sql.QueryTest import org.apache.spark.sql.hive.test.{TestHive, TestHiveContext} +import org.apache.spark.sql.types.DecimalType import org.apache.spark.util.{ResetSystemProperties, Utils} -import org.scalatest.Matchers -import org.scalatest.concurrent.Timeouts -import org.scalatest.time.SpanSugar._ /** * This suite tests spark-submit with applications using HiveContext. @@ -50,8 +52,8 @@ class HiveSparkSubmitSuite val unusedJar = TestUtils.createJarWithClasses(Seq.empty) val jar1 = TestUtils.createJarWithClasses(Seq("SparkSubmitClassA")) val jar2 = TestUtils.createJarWithClasses(Seq("SparkSubmitClassB")) - val jar3 = TestHive.getHiveFile("hive-contrib-0.13.1.jar").getCanonicalPath() - val jar4 = TestHive.getHiveFile("hive-hcatalog-core-0.13.1.jar").getCanonicalPath() + val jar3 = TestHive.getHiveFile("hive-contrib-0.13.1.jar").getCanonicalPath + val jar4 = TestHive.getHiveFile("hive-hcatalog-core-0.13.1.jar").getCanonicalPath val jarsString = Seq(jar1, jar2, jar3, jar4).map(j => j.toString).mkString(",") val args = Seq( "--class", SparkSubmitClassLoaderTest.getClass.getName.stripSuffix("$"), @@ -91,6 +93,16 @@ class HiveSparkSubmitSuite runSparkSubmit(args) } + test("SPARK-9757 Persist Parquet relation with decimal column") { + val unusedJar = TestUtils.createJarWithClasses(Seq.empty) + val args = Seq( + "--class", SPARK_9757.getClass.getName.stripSuffix("$"), + "--name", "SparkSQLConfTest", + "--master", "local-cluster[2,1,1024]", + unusedJar.toString) + runSparkSubmit(args) + } + // NOTE: This is an expensive operation in terms of time (10 seconds+). Use sparingly. // This is copied from org.apache.spark.deploy.SparkSubmitSuite private def runSparkSubmit(args: Seq[String]): Unit = { @@ -213,7 +225,7 @@ object SparkSQLConfTest extends Logging { // before spark.sql.hive.metastore.jars get set, we will see the following exception: // Exception in thread "main" java.lang.IllegalArgumentException: Builtin jars can only // be used when hive execution version == hive metastore version. - // Execution: 0.13.1 != Metastore: 0.12. Specify a vaild path to the correct hive jars + // Execution: 0.13.1 != Metastore: 0.12. Specify a valid path to the correct hive jars // using $HIVE_METASTORE_JARS or change spark.sql.hive.metastore.version to 0.13.1. val conf = new SparkConf() { override def getAll: Array[(String, String)] = { @@ -239,3 +251,45 @@ object SparkSQLConfTest extends Logging { sc.stop() } } + +object SPARK_9757 extends QueryTest with Logging { + def main(args: Array[String]): Unit = { + Utils.configTestLog4j("INFO") + + val sparkContext = new SparkContext( + new SparkConf() + .set("spark.sql.hive.metastore.version", "0.13.1") + .set("spark.sql.hive.metastore.jars", "maven")) + + val hiveContext = new TestHiveContext(sparkContext) + import hiveContext.implicits._ + import org.apache.spark.sql.functions._ + + val dir = Utils.createTempDir() + dir.delete() + + try { + { + val df = + hiveContext + .range(10) + .select(('id + 0.1) cast DecimalType(10, 3) as 'dec) + df.write.option("path", dir.getCanonicalPath).mode("overwrite").saveAsTable("t") + checkAnswer(hiveContext.table("t"), df) + } + + { + val df = + hiveContext + .range(10) + .select(callUDF("struct", ('id + 0.2) cast DecimalType(10, 3)) as 'dec_struct) + df.write.option("path", dir.getCanonicalPath).mode("overwrite").saveAsTable("t") + checkAnswer(hiveContext.table("t"), df) + } + } finally { + dir.delete() + hiveContext.sql("DROP TABLE t") + sparkContext.stop() + } + } +}