From 69930310115501f0de094fe6f5c6c60dade342bd Mon Sep 17 00:00:00 2001
From: Cheng Lian <lian@databricks.com>
Date: Thu, 13 Aug 2015 16:16:50 +0800
Subject: [PATCH] [SPARK-9757] [SQL] Fixes persistence of Parquet relation with
 decimal column

PR #7967 enables us to save data source relations to metastore in Hive compatible format when possible. But it fails to persist Parquet relations with decimal column(s) to Hive metastore of versions lower than 1.2.0. This is because `ParquetHiveSerDe` in Hive versions prior to 1.2.0 doesn't support decimal. This PR checks for this case and falls back to Spark SQL specific metastore table format.

Author: Yin Huai <yhuai@databricks.com>
Author: Cheng Lian <lian@databricks.com>

Closes #8130 from liancheng/spark-9757/old-hive-parquet-decimal.
---
 .../apache/spark/sql/types/ArrayType.scala    |  6 +-
 .../org/apache/spark/sql/types/DataType.scala |  5 ++
 .../org/apache/spark/sql/types/MapType.scala  |  6 +-
 .../apache/spark/sql/types/StructType.scala   |  8 ++-
 .../spark/sql/types/DataTypeSuite.scala       | 24 +++++++
 .../spark/sql/hive/HiveMetastoreCatalog.scala | 39 ++++++++---
 .../sql/hive/client/ClientInterface.scala     |  3 +
 .../spark/sql/hive/client/ClientWrapper.scala |  2 +-
 .../spark/sql/hive/client/package.scala       |  2 +-
 .../sql/hive/HiveMetastoreCatalogSuite.scala  | 17 +++--
 .../spark/sql/hive/HiveSparkSubmitSuite.scala | 68 +++++++++++++++++--
 11 files changed, 150 insertions(+), 30 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala
index 5094058164..5770f59b53 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala
@@ -75,6 +75,10 @@ case class ArrayType(elementType: DataType, containsNull: Boolean) extends DataT
 
   override def simpleString: String = s"array<${elementType.simpleString}>"
 
-  private[spark] override def asNullable: ArrayType =
+  override private[spark] def asNullable: ArrayType =
     ArrayType(elementType.asNullable, containsNull = true)
+
+  override private[spark] def existsRecursively(f: (DataType) => Boolean): Boolean = {
+    f(this) || elementType.existsRecursively(f)
+  }
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
index f4428c2e8b..7bcd623b3f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
@@ -77,6 +77,11 @@ abstract class DataType extends AbstractDataType {
    */
   private[spark] def asNullable: DataType
 
+  /**
+   * Returns true if any `DataType` of this DataType tree satisfies the given function `f`.
+   */
+  private[spark] def existsRecursively(f: (DataType) => Boolean): Boolean = f(this)
+
   override private[sql] def defaultConcreteType: DataType = this
 
   override private[sql] def acceptsType(other: DataType): Boolean = sameType(other)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala
index ac34b64282..00461e529c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala
@@ -62,8 +62,12 @@ case class MapType(
 
   override def simpleString: String = s"map<${keyType.simpleString},${valueType.simpleString}>"
 
-  private[spark] override def asNullable: MapType =
+  override private[spark] def asNullable: MapType =
     MapType(keyType.asNullable, valueType.asNullable, valueContainsNull = true)
+
+  override private[spark] def existsRecursively(f: (DataType) => Boolean): Boolean = {
+    f(this) || keyType.existsRecursively(f) || valueType.existsRecursively(f)
+  }
 }
 
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala
index 9cbc207538..d8968ef806 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala
@@ -24,7 +24,7 @@ import org.json4s.JsonDSL._
 
 import org.apache.spark.SparkException
 import org.apache.spark.annotation.DeveloperApi
-import org.apache.spark.sql.catalyst.expressions.{InterpretedOrdering, AttributeReference, Attribute, InterpretedOrdering$}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, InterpretedOrdering}
 
 
 /**
@@ -292,7 +292,7 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru
   private[sql] def merge(that: StructType): StructType =
     StructType.merge(this, that).asInstanceOf[StructType]
 
-  private[spark] override def asNullable: StructType = {
+  override private[spark] def asNullable: StructType = {
     val newFields = fields.map {
       case StructField(name, dataType, nullable, metadata) =>
         StructField(name, dataType.asNullable, nullable = true, metadata)
@@ -301,6 +301,10 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru
     StructType(newFields)
   }
 
+  override private[spark] def existsRecursively(f: (DataType) => Boolean): Boolean = {
+    f(this) || fields.exists(field => field.dataType.existsRecursively(f))
+  }
+
   private[sql] val interpretedOrdering = InterpretedOrdering.forSchema(this.fields.map(_.dataType))
 }
 
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala
index 88b221cd81..706ecd29d1 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala
@@ -170,6 +170,30 @@ class DataTypeSuite extends SparkFunSuite {
     }
   }
 
+  test("existsRecursively") {
+    val struct = StructType(
+      StructField("a", LongType) ::
+      StructField("b", FloatType) :: Nil)
+    assert(struct.existsRecursively(_.isInstanceOf[LongType]))
+    assert(struct.existsRecursively(_.isInstanceOf[StructType]))
+    assert(!struct.existsRecursively(_.isInstanceOf[IntegerType]))
+
+    val mapType = MapType(struct, StringType)
+    assert(mapType.existsRecursively(_.isInstanceOf[LongType]))
+    assert(mapType.existsRecursively(_.isInstanceOf[StructType]))
+    assert(mapType.existsRecursively(_.isInstanceOf[StringType]))
+    assert(mapType.existsRecursively(_.isInstanceOf[MapType]))
+    assert(!mapType.existsRecursively(_.isInstanceOf[IntegerType]))
+
+    val arrayType = ArrayType(mapType)
+    assert(arrayType.existsRecursively(_.isInstanceOf[LongType]))
+    assert(arrayType.existsRecursively(_.isInstanceOf[StructType]))
+    assert(arrayType.existsRecursively(_.isInstanceOf[StringType]))
+    assert(arrayType.existsRecursively(_.isInstanceOf[MapType]))
+    assert(arrayType.existsRecursively(_.isInstanceOf[ArrayType]))
+    assert(!arrayType.existsRecursively(_.isInstanceOf[IntegerType]))
+  }
+
   def checkDataTypeJsonRepr(dataType: DataType): Unit = {
     test(s"JSON - $dataType") {
       assert(DataType.fromJson(dataType.json) === dataType)
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
index 5e5497837a..6770462bb0 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
@@ -33,15 +33,14 @@ import org.apache.hadoop.hive.ql.plan.TableDesc
 import org.apache.spark.Logging
 import org.apache.spark.sql.catalyst.analysis.{Catalog, MultiInstanceRelation, OverrideCatalog}
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.planning.PhysicalOperation
 import org.apache.spark.sql.catalyst.plans.logical
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules._
 import org.apache.spark.sql.catalyst.{InternalRow, SqlParser, TableIdentifier}
-import org.apache.spark.sql.execution.{FileRelation, datasources}
+import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation
 import org.apache.spark.sql.execution.datasources.{CreateTableUsingAsSelect, LogicalRelation, Partition => ParquetPartition, PartitionSpec, ResolvedDataSource}
+import org.apache.spark.sql.execution.{FileRelation, datasources}
 import org.apache.spark.sql.hive.client._
-import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation
 import org.apache.spark.sql.sources._
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.{AnalysisException, SQLContext, SaveMode}
@@ -86,9 +85,9 @@ private[hive] object HiveSerDe {
           serde = Option("org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe")))
 
     val key = source.toLowerCase match {
-      case _ if source.startsWith("org.apache.spark.sql.parquet") => "parquet"
-      case _ if source.startsWith("org.apache.spark.sql.orc") => "orc"
-      case _ => source.toLowerCase
+      case s if s.startsWith("org.apache.spark.sql.parquet") => "parquet"
+      case s if s.startsWith("org.apache.spark.sql.orc") => "orc"
+      case s => s
     }
 
     serdeMap.get(key)
@@ -309,11 +308,31 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive
     val hiveTable = (maybeSerDe, dataSource.relation) match {
       case (Some(serde), relation: HadoopFsRelation)
           if relation.paths.length == 1 && relation.partitionColumns.isEmpty =>
-        logInfo {
-          "Persisting data source relation with a single input path into Hive metastore in Hive " +
-            s"compatible format.  Input path: ${relation.paths.head}"
+        // Hive ParquetSerDe doesn't support decimal type until 1.2.0.
+        val isParquetSerDe = serde.inputFormat.exists(_.toLowerCase.contains("parquet"))
+        val hasDecimalFields = relation.schema.existsRecursively(_.isInstanceOf[DecimalType])
+
+        val hiveParquetSupportsDecimal = client.version match {
+          case org.apache.spark.sql.hive.client.hive.v1_2 => true
+          case _ => false
+        }
+
+        if (isParquetSerDe && !hiveParquetSupportsDecimal && hasDecimalFields) {
+          // If Hive version is below 1.2.0, we cannot save Hive compatible schema to
+          // metastore when the file format is Parquet and the schema has DecimalType.
+          logWarning {
+            "Persisting Parquet relation with decimal field(s) into Hive metastore in Spark SQL " +
+              "specific format, which is NOT compatible with Hive. Because ParquetHiveSerDe in " +
+              s"Hive ${client.version.fullVersion} doesn't support decimal type. See HIVE-6384."
+          }
+          newSparkSQLSpecificMetastoreTable()
+        } else {
+          logInfo {
+            "Persisting data source relation with a single input path into Hive metastore in " +
+              s"Hive compatible format. Input path: ${relation.paths.head}"
+          }
+          newHiveCompatibleMetastoreTable(relation, serde)
         }
-        newHiveCompatibleMetastoreTable(relation, serde)
 
       case (Some(serde), relation: HadoopFsRelation) if relation.partitionColumns.nonEmpty =>
         logWarning {
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientInterface.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientInterface.scala
index a82e152dcd..3811c152a7 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientInterface.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientInterface.scala
@@ -88,6 +88,9 @@ private[hive] case class HiveTable(
  */
 private[hive] trait ClientInterface {
 
+  /** Returns the Hive Version of this client. */
+  def version: HiveVersion
+
   /** Returns the configuration for the given key in the current session. */
   def getConf(key: String, defaultValue: String): String
 
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala
index 3d05b583cf..f49c97de8f 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala
@@ -58,7 +58,7 @@ import org.apache.spark.util.{CircularBuffer, Utils}
  *                        this ClientWrapper.
  */
 private[hive] class ClientWrapper(
-    version: HiveVersion,
+    override val version: HiveVersion,
     config: Map[String, String],
     initClassLoader: ClassLoader)
   extends ClientInterface
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala
index 0503691a44..b1b8439efa 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala
@@ -25,7 +25,7 @@ package object client {
       val exclusions: Seq[String] = Nil)
 
   // scalastyle:off
-  private[client] object hive {
+  private[hive] object hive {
     case object v12 extends HiveVersion("0.12.0")
     case object v13 extends HiveVersion("0.13.1")
 
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala
index 332c3ec0c2..59e65ff97b 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala
@@ -19,13 +19,13 @@ package org.apache.spark.sql.hive
 
 import java.io.File
 
-import org.apache.spark.sql.hive.client.{ExternalTable, HiveColumn, ManagedTable}
+import org.apache.spark.sql.hive.client.{ExternalTable, ManagedTable}
 import org.apache.spark.sql.hive.test.TestHive
 import org.apache.spark.sql.hive.test.TestHive._
 import org.apache.spark.sql.hive.test.TestHive.implicits._
 import org.apache.spark.sql.sources.DataSourceTest
 import org.apache.spark.sql.test.{ExamplePointUDT, SQLTestUtils}
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.types.{DecimalType, StringType, StructType}
 import org.apache.spark.sql.{Row, SaveMode}
 import org.apache.spark.{Logging, SparkFunSuite}
 
@@ -55,7 +55,10 @@ class HiveMetastoreCatalogSuite extends SparkFunSuite with Logging {
 class DataSourceWithHiveMetastoreCatalogSuite extends DataSourceTest with SQLTestUtils {
   override val sqlContext = TestHive
 
-  private val testDF = (1 to 2).map(i => (i, s"val_$i")).toDF("d1", "d2").coalesce(1)
+  private val testDF = range(1, 3).select(
+    ('id + 0.1) cast DecimalType(10, 3) as 'd1,
+    'id cast StringType as 'd2
+  ).coalesce(1)
 
   Seq(
     "parquet" -> (
@@ -88,10 +91,10 @@ class DataSourceWithHiveMetastoreCatalogSuite extends DataSourceTest with SQLTes
 
         val columns = hiveTable.schema
         assert(columns.map(_.name) === Seq("d1", "d2"))
-        assert(columns.map(_.hiveType) === Seq("int", "string"))
+        assert(columns.map(_.hiveType) === Seq("decimal(10,3)", "string"))
 
         checkAnswer(table("t"), testDF)
-        assert(runSqlHive("SELECT * FROM t") === Seq("1\tval_1", "2\tval_2"))
+        assert(runSqlHive("SELECT * FROM t") === Seq("1.1\t1", "2.1\t2"))
       }
     }
 
@@ -117,10 +120,10 @@ class DataSourceWithHiveMetastoreCatalogSuite extends DataSourceTest with SQLTes
 
           val columns = hiveTable.schema
           assert(columns.map(_.name) === Seq("d1", "d2"))
-          assert(columns.map(_.hiveType) === Seq("int", "string"))
+          assert(columns.map(_.hiveType) === Seq("decimal(10,3)", "string"))
 
           checkAnswer(table("t"), testDF)
-          assert(runSqlHive("SELECT * FROM t") === Seq("1\tval_1", "2\tval_2"))
+          assert(runSqlHive("SELECT * FROM t") === Seq("1.1\t1", "2.1\t2"))
         }
       }
     }
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala
index 1e1972d1ac..0c29646114 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala
@@ -20,16 +20,18 @@ package org.apache.spark.sql.hive
 import java.io.File
 
 import scala.collection.mutable.ArrayBuffer
-import scala.sys.process.{ProcessLogger, Process}
+import scala.sys.process.{Process, ProcessLogger}
 
+import org.scalatest.Matchers
+import org.scalatest.concurrent.Timeouts
 import org.scalatest.exceptions.TestFailedDueToTimeoutException
+import org.scalatest.time.SpanSugar._
 
 import org.apache.spark._
+import org.apache.spark.sql.QueryTest
 import org.apache.spark.sql.hive.test.{TestHive, TestHiveContext}
+import org.apache.spark.sql.types.DecimalType
 import org.apache.spark.util.{ResetSystemProperties, Utils}
-import org.scalatest.Matchers
-import org.scalatest.concurrent.Timeouts
-import org.scalatest.time.SpanSugar._
 
 /**
  * This suite tests spark-submit with applications using HiveContext.
@@ -50,8 +52,8 @@ class HiveSparkSubmitSuite
     val unusedJar = TestUtils.createJarWithClasses(Seq.empty)
     val jar1 = TestUtils.createJarWithClasses(Seq("SparkSubmitClassA"))
     val jar2 = TestUtils.createJarWithClasses(Seq("SparkSubmitClassB"))
-    val jar3 = TestHive.getHiveFile("hive-contrib-0.13.1.jar").getCanonicalPath()
-    val jar4 = TestHive.getHiveFile("hive-hcatalog-core-0.13.1.jar").getCanonicalPath()
+    val jar3 = TestHive.getHiveFile("hive-contrib-0.13.1.jar").getCanonicalPath
+    val jar4 = TestHive.getHiveFile("hive-hcatalog-core-0.13.1.jar").getCanonicalPath
     val jarsString = Seq(jar1, jar2, jar3, jar4).map(j => j.toString).mkString(",")
     val args = Seq(
       "--class", SparkSubmitClassLoaderTest.getClass.getName.stripSuffix("$"),
@@ -91,6 +93,16 @@ class HiveSparkSubmitSuite
     runSparkSubmit(args)
   }
 
+  test("SPARK-9757 Persist Parquet relation with decimal column") {
+    val unusedJar = TestUtils.createJarWithClasses(Seq.empty)
+    val args = Seq(
+      "--class", SPARK_9757.getClass.getName.stripSuffix("$"),
+      "--name", "SparkSQLConfTest",
+      "--master", "local-cluster[2,1,1024]",
+      unusedJar.toString)
+    runSparkSubmit(args)
+  }
+
   // NOTE: This is an expensive operation in terms of time (10 seconds+). Use sparingly.
   // This is copied from org.apache.spark.deploy.SparkSubmitSuite
   private def runSparkSubmit(args: Seq[String]): Unit = {
@@ -213,7 +225,7 @@ object SparkSQLConfTest extends Logging {
     // before spark.sql.hive.metastore.jars get set, we will see the following exception:
     // Exception in thread "main" java.lang.IllegalArgumentException: Builtin jars can only
     // be used when hive execution version == hive metastore version.
-    // Execution: 0.13.1 != Metastore: 0.12. Specify a vaild path to the correct hive jars
+    // Execution: 0.13.1 != Metastore: 0.12. Specify a valid path to the correct hive jars
     // using $HIVE_METASTORE_JARS or change spark.sql.hive.metastore.version to 0.13.1.
     val conf = new SparkConf() {
       override def getAll: Array[(String, String)] = {
@@ -239,3 +251,45 @@ object SparkSQLConfTest extends Logging {
     sc.stop()
   }
 }
+
+object SPARK_9757 extends QueryTest with Logging {
+  def main(args: Array[String]): Unit = {
+    Utils.configTestLog4j("INFO")
+
+    val sparkContext = new SparkContext(
+      new SparkConf()
+        .set("spark.sql.hive.metastore.version", "0.13.1")
+        .set("spark.sql.hive.metastore.jars", "maven"))
+
+    val hiveContext = new TestHiveContext(sparkContext)
+    import hiveContext.implicits._
+    import org.apache.spark.sql.functions._
+
+    val dir = Utils.createTempDir()
+    dir.delete()
+
+    try {
+      {
+        val df =
+          hiveContext
+            .range(10)
+            .select(('id + 0.1) cast DecimalType(10, 3) as 'dec)
+        df.write.option("path", dir.getCanonicalPath).mode("overwrite").saveAsTable("t")
+        checkAnswer(hiveContext.table("t"), df)
+      }
+
+      {
+        val df =
+          hiveContext
+            .range(10)
+            .select(callUDF("struct", ('id + 0.2) cast DecimalType(10, 3)) as 'dec_struct)
+        df.write.option("path", dir.getCanonicalPath).mode("overwrite").saveAsTable("t")
+        checkAnswer(hiveContext.table("t"), df)
+      }
+    } finally {
+      dir.delete()
+      hiveContext.sql("DROP TABLE t")
+      sparkContext.stop()
+    }
+  }
+}
-- 
GitLab