From 37686539f546ac7a3657dbfc59b7ac982b4b9bce Mon Sep 17 00:00:00 2001
From: hyukjinkwon <gurwls223@gmail.com>
Date: Tue, 18 Oct 2016 13:20:42 -0700
Subject: [PATCH] [SPARK-17388] [SQL] Support for inferring type
 date/timestamp/decimal for partition column

## What changes were proposed in this pull request?

Currently, Spark only supports to infer `IntegerType`, `LongType`, `DoubleType` and `StringType`.

`DecimalType` is being tried but it seems it never infers type as `DecimalType` as `DoubleType` is being tried first. Also, it seems `DateType` and `TimestampType` could be inferred.

As far as I know, it is pretty common to use both for a partition column.

This PR fixes the incorrect `DecimalType` try and also adds the support for both `DateType` and `TimestampType` for inferring partition column type.

## How was this patch tested?

Unit tests in `ParquetPartitionDiscoverySuite`.

Author: hyukjinkwon <gurwls223@gmail.com>

Closes #14947 from HyukjinKwon/SPARK-17388.
---
 .../datasources/PartitioningUtils.scala       | 21 ++++++++--
 .../ParquetPartitionDiscoverySuite.scala      | 42 ++++++++++++++++++-
 2 files changed, 59 insertions(+), 4 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala
index 504464216e..381261cf65 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.datasources
 
 import java.lang.{Double => JDouble, Long => JLong}
 import java.math.{BigDecimal => JBigDecimal}
+import java.sql.{Date => JDate, Timestamp => JTimestamp}
 
 import scala.collection.mutable.ArrayBuffer
 import scala.util.Try
@@ -307,20 +308,34 @@ object PartitioningUtils {
 
   /**
    * Converts a string to a [[Literal]] with automatic type inference.  Currently only supports
-   * [[IntegerType]], [[LongType]], [[DoubleType]], [[DecimalType.SYSTEM_DEFAULT]], and
-   * [[StringType]].
+   * [[IntegerType]], [[LongType]], [[DoubleType]], [[DecimalType]], [[DateType]]
+   * [[TimestampType]], and [[StringType]].
    */
   private[datasources] def inferPartitionColumnValue(
       raw: String,
       defaultPartitionName: String,
       typeInference: Boolean): Literal = {
+    val decimalTry = Try {
+      // `BigDecimal` conversion can fail when the `field` is not a form of number.
+      val bigDecimal = new JBigDecimal(raw)
+      // It reduces the cases for decimals by disallowing values having scale (eg. `1.1`).
+      require(bigDecimal.scale <= 0)
+      // `DecimalType` conversion can fail when
+      //   1. The precision is bigger than 38.
+      //   2. scale is bigger than precision.
+      Literal(bigDecimal)
+    }
+
     if (typeInference) {
       // First tries integral types
       Try(Literal.create(Integer.parseInt(raw), IntegerType))
         .orElse(Try(Literal.create(JLong.parseLong(raw), LongType)))
+        .orElse(decimalTry)
         // Then falls back to fractional types
         .orElse(Try(Literal.create(JDouble.parseDouble(raw), DoubleType)))
-        .orElse(Try(Literal(new JBigDecimal(raw))))
+        // Then falls back to date/timestamp types
+        .orElse(Try(Literal(JDate.valueOf(raw))))
+        .orElse(Try(Literal(JTimestamp.valueOf(unescapePathName(raw)))))
         // Then falls back to string
         .getOrElse {
           if (raw == defaultPartitionName) {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala
index 43357c97c3..2ef66baee1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.datasources.parquet
 
 import java.io.File
 import java.math.BigInteger
-import java.sql.Timestamp
+import java.sql.{Date, Timestamp}
 
 import scala.collection.mutable.ArrayBuffer
 
@@ -56,8 +56,14 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha
 
     check("10", Literal.create(10, IntegerType))
     check("1000000000000000", Literal.create(1000000000000000L, LongType))
+    val decimal = Decimal("1" * 20)
+    check("1" * 20,
+      Literal.create(decimal, DecimalType(decimal.precision, decimal.scale)))
     check("1.5", Literal.create(1.5, DoubleType))
     check("hello", Literal.create("hello", StringType))
+    check("1990-02-24", Literal.create(Date.valueOf("1990-02-24"), DateType))
+    check("1990-02-24 12:00:30",
+      Literal.create(Timestamp.valueOf("1990-02-24 12:00:30"), TimestampType))
     check(defaultPartitionName, Literal.create(null, NullType))
   }
 
@@ -687,6 +693,40 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha
     }
   }
 
+  test("Various inferred partition value types") {
+    val row =
+      Row(
+        Long.MaxValue,
+        4.5,
+        new java.math.BigDecimal(new BigInteger("1" * 20)),
+        java.sql.Date.valueOf("2015-05-23"),
+        java.sql.Timestamp.valueOf("1990-02-24 12:00:30"),
+        "This is a string, /[]?=:",
+        "This is not a partition column")
+
+    val partitionColumnTypes =
+      Seq(
+        LongType,
+        DoubleType,
+        DecimalType(20, 0),
+        DateType,
+        TimestampType,
+        StringType)
+
+    val partitionColumns = partitionColumnTypes.zipWithIndex.map {
+      case (t, index) => StructField(s"p_$index", t)
+    }
+
+    val schema = StructType(partitionColumns :+ StructField(s"i", StringType))
+    val df = spark.createDataFrame(sparkContext.parallelize(row :: Nil), schema)
+
+    withTempPath { dir =>
+      df.write.format("parquet").partitionBy(partitionColumns.map(_.name): _*).save(dir.toString)
+      val fields = schema.map(f => Column(f.name))
+      checkAnswer(spark.read.load(dir.toString).select(fields: _*), row)
+    }
+  }
+
   test("SPARK-8037: Ignores files whose name starts with dot") {
     withTempPath { dir =>
       val df = (1 to 3).map(i => (i, i, i, i)).toDF("a", "b", "c", "d")
-- 
GitLab