From 64b1d00e1a7c1dc52c08a5e97baf6e7117f1a94f Mon Sep 17 00:00:00 2001
From: Cheng Lian <lian@databricks.com>
Date: Mon, 12 Oct 2015 10:17:19 -0700
Subject: [PATCH] [SPARK-11007] [SQL] Adds dictionary aware Parquet decimal
 converters

For Parquet decimal columns that are encoded using plain-dictionary encoding, we can make the upper level converter aware of the dictionary, so that we can pre-instantiate all the decimals to avoid duplicated instantiation.

Note that plain-dictionary encoding isn't available for `FIXED_LEN_BYTE_ARRAY` for Parquet writer version `PARQUET_1_0`. So currently only decimals written as `INT32` and `INT64` can benefit from this optimization.

Author: Cheng Lian <lian@databricks.com>

Closes #9040 from liancheng/spark-11007.decimal-converter-dict-support.
---
 .../parquet/CatalystRowConverter.scala        |  83 +++++++++++++++---
 .../src/test/resources/dec-in-i32.parquet     | Bin 0 -> 420 bytes
 .../src/test/resources/dec-in-i64.parquet     | Bin 0 -> 437 bytes
 .../datasources/parquet/ParquetIOSuite.scala  |  19 ++++
 .../ParquetProtobufCompatibilitySuite.scala   |  22 ++---
 .../datasources/parquet/ParquetTest.scala     |   5 ++
 6 files changed, 103 insertions(+), 26 deletions(-)
 create mode 100755 sql/core/src/test/resources/dec-in-i32.parquet
 create mode 100755 sql/core/src/test/resources/dec-in-i64.parquet

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala
index 247d35363b..49007e45ec 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala
@@ -26,7 +26,7 @@ import scala.collection.mutable.ArrayBuffer
 import org.apache.parquet.column.Dictionary
 import org.apache.parquet.io.api.{Binary, Converter, GroupConverter, PrimitiveConverter}
 import org.apache.parquet.schema.OriginalType.{INT_32, LIST, UTF8}
-import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName.DOUBLE
+import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName.{DOUBLE, INT32, INT64, BINARY, FIXED_LEN_BYTE_ARRAY}
 import org.apache.parquet.schema.{GroupType, MessageType, PrimitiveType, Type}
 
 import org.apache.spark.Logging
@@ -222,8 +222,25 @@ private[parquet] class CatalystRowConverter(
             updater.setShort(value.asInstanceOf[ShortType#InternalType])
         }
 
+      // For INT32 backed decimals
+      case t: DecimalType if parquetType.asPrimitiveType().getPrimitiveTypeName == INT32 =>
+        new CatalystIntDictionaryAwareDecimalConverter(t.precision, t.scale, updater)
+
+      // For INT64 backed decimals
+      case t: DecimalType if parquetType.asPrimitiveType().getPrimitiveTypeName == INT64 =>
+        new CatalystLongDictionaryAwareDecimalConverter(t.precision, t.scale, updater)
+
+      // For BINARY and FIXED_LEN_BYTE_ARRAY backed decimals
+      case t: DecimalType
+        if parquetType.asPrimitiveType().getPrimitiveTypeName == FIXED_LEN_BYTE_ARRAY ||
+           parquetType.asPrimitiveType().getPrimitiveTypeName == BINARY =>
+        new CatalystBinaryDictionaryAwareDecimalConverter(t.precision, t.scale, updater)
+
       case t: DecimalType =>
-        new CatalystDecimalConverter(t, updater)
+        throw new RuntimeException(
+          s"Unable to create Parquet converter for decimal type ${t.json} whose Parquet type is " +
+            s"$parquetType.  Parquet DECIMAL type can only be backed by INT32, INT64, " +
+            "FIXED_LEN_BYTE_ARRAY, or BINARY.")
 
       case StringType =>
         new CatalystStringConverter(updater)
@@ -274,9 +291,10 @@ private[parquet] class CatalystRowConverter(
           override def set(value: Any): Unit = updater.set(value.asInstanceOf[InternalRow].copy())
         })
 
-      case _ =>
+      case t =>
         throw new RuntimeException(
-          s"Unable to create Parquet converter for data type ${catalystType.json}")
+          s"Unable to create Parquet converter for data type ${t.json} " +
+            s"whose Parquet type is $parquetType")
     }
   }
 
@@ -314,11 +332,18 @@ private[parquet] class CatalystRowConverter(
   /**
    * Parquet converter for fixed-precision decimals.
    */
-  private final class CatalystDecimalConverter(
-      decimalType: DecimalType,
-      updater: ParentContainerUpdater)
+  private abstract class CatalystDecimalConverter(
+      precision: Int, scale: Int, updater: ParentContainerUpdater)
     extends CatalystPrimitiveConverter(updater) {
 
+    protected var expandedDictionary: Array[Decimal] = _
+
+    override def hasDictionarySupport: Boolean = true
+
+    override def addValueFromDictionary(dictionaryId: Int): Unit = {
+      updater.set(expandedDictionary(dictionaryId))
+    }
+
     // Converts decimals stored as INT32
     override def addInt(value: Int): Unit = {
       addLong(value: Long)
@@ -326,18 +351,19 @@ private[parquet] class CatalystRowConverter(
 
     // Converts decimals stored as INT64
     override def addLong(value: Long): Unit = {
-      updater.set(Decimal(value, decimalType.precision, decimalType.scale))
+      updater.set(decimalFromLong(value))
     }
 
     // Converts decimals stored as either FIXED_LENGTH_BYTE_ARRAY or BINARY
     override def addBinary(value: Binary): Unit = {
-      updater.set(toDecimal(value))
+      updater.set(decimalFromBinary(value))
     }
 
-    private def toDecimal(value: Binary): Decimal = {
-      val precision = decimalType.precision
-      val scale = decimalType.scale
+    protected def decimalFromLong(value: Long): Decimal = {
+      Decimal(value, precision, scale)
+    }
 
+    protected def decimalFromBinary(value: Binary): Decimal = {
       if (precision <= CatalystSchemaConverter.MAX_PRECISION_FOR_INT64) {
         // Constructs a `Decimal` with an unscaled `Long` value if possible.
         val unscaled = binaryToUnscaledLong(value)
@@ -371,6 +397,39 @@ private[parquet] class CatalystRowConverter(
     }
   }
 
+  private class CatalystIntDictionaryAwareDecimalConverter(
+      precision: Int, scale: Int, updater: ParentContainerUpdater)
+    extends CatalystDecimalConverter(precision, scale, updater) {
+
+    override def setDictionary(dictionary: Dictionary): Unit = {
+      this.expandedDictionary = Array.tabulate(dictionary.getMaxId + 1) { id =>
+        decimalFromLong(dictionary.decodeToInt(id).toLong)
+      }
+    }
+  }
+
+  private class CatalystLongDictionaryAwareDecimalConverter(
+      precision: Int, scale: Int, updater: ParentContainerUpdater)
+    extends CatalystDecimalConverter(precision, scale, updater) {
+
+    override def setDictionary(dictionary: Dictionary): Unit = {
+      this.expandedDictionary = Array.tabulate(dictionary.getMaxId + 1) { id =>
+        decimalFromLong(dictionary.decodeToLong(id))
+      }
+    }
+  }
+
+  private class CatalystBinaryDictionaryAwareDecimalConverter(
+      precision: Int, scale: Int, updater: ParentContainerUpdater)
+    extends CatalystDecimalConverter(precision, scale, updater) {
+
+    override def setDictionary(dictionary: Dictionary): Unit = {
+      this.expandedDictionary = Array.tabulate(dictionary.getMaxId + 1) { id =>
+        decimalFromBinary(dictionary.decodeToBinary(id))
+      }
+    }
+  }
+
   /**
    * Parquet converter for arrays.  Spark SQL arrays are represented as Parquet lists.  Standard
    * Parquet lists are represented as a 3-level group annotated by `LIST`:
diff --git a/sql/core/src/test/resources/dec-in-i32.parquet b/sql/core/src/test/resources/dec-in-i32.parquet
new file mode 100755
index 0000000000000000000000000000000000000000..bb5d4af8dd36817bfb2cc16746417f0f2cbec759
GIT binary patch
literal 420
zcmWG=3^EjD5e*Pc^AQyhWno~D@8)2DfaHXP1P{hX!VYT=pEzK^*s-6XkVTmJu$)3z
zLRvxwV-mx>L&=vkfNDh<L={9`bbthlD4QsUj08&yGXsMJ&@cuDF%W@dW>P{zKtf8Q
zr~#MeY(^CZr+?eF2>?}yGD+%q@Dvv$7G=j5CugMQCW<lv1yz|O*fWid;!{$SRk?ts
zb1{f1NXkgcsBy>ub(penut~xdh_Z+&h@D{+YhzO5u)%NwNJfD{Qbs~EzbIWVu^<s>
zi5}QKz2d?gJ)p&frKu%)Mfv4=xv3?IDTyVC63Nv{C6xuKN>)n6B}JvlB}zI<X_=`x
zDaA@w(bY<MiMb#tsPlkwP_;m}X67d5Xqf64X#z#_N^^1&lX8Gcfo7!YD8WouvZ}7F
ljjd&nkbv5)n_Hw%mReMtnV+X%sAr~Uz#z)Vzz_h89{|hbXDk2!

literal 0
HcmV?d00001

diff --git a/sql/core/src/test/resources/dec-in-i64.parquet b/sql/core/src/test/resources/dec-in-i64.parquet
new file mode 100755
index 0000000000000000000000000000000000000000..e07c4a0ad9843dca983e17d784566aae7b4be54e
GIT binary patch
literal 437
zcmWG=3^EjD5naG2n&KlWBFe(RAm7cw00GGf4Gkh1wM<^G4V+$Z2K?fl(wES5p?bj<
zCgYa8#!C!QD$gec0F{a|h$@J>=l}^8Q8rNy83~RSW{3$AFryg6KmtfcCnY2VB%~yY
z8gOaOW>jHt`nPSH0LUmNNgWTK;)2AY?D*p3jMUsjQ6>ga7F8w*_DnOA_>|OSRW6_{
zA`D^*k}{GqY8*16ERv=y9Bh(s1)?ls3S#S+#HKN+aoFH=3P^<lgQSdvW`0q+USdHa
z&@w%+y?VukIeI_`6qcrz=oRIc>*c1FB&H;mBub=IE0t6hq$*h{6_*s1CYLDbD5Yhl
z=A;xWSw&YX<t65Vq@d0O%0blv-JF@5n4@86pkt&76wWKn$w^Gg0jdQWlB%NwGhWH6
ny0$j9mO(-SYPoK1kwRH&QE_H|o`RvCnVtcI93ulm05HM;Rw!wM

literal 0
HcmV?d00001

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala
index 599cf948e7..7274479989 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala
@@ -488,6 +488,25 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext {
       clonedConf.asScala.foreach(entry => hadoopConfiguration.set(entry.getKey, entry.getValue))
     }
   }
+
+  test("read dictionary encoded decimals written as INT32") {
+    checkAnswer(
+      // Decimal column in this file is encoded using plain dictionary
+      readResourceParquetFile("dec-in-i32.parquet"),
+      sqlContext.range(1 << 4).select('id % 10 cast DecimalType(5, 2) as 'i32_dec))
+  }
+
+  test("read dictionary encoded decimals written as INT64") {
+    checkAnswer(
+      // Decimal column in this file is encoded using plain dictionary
+      readResourceParquetFile("dec-in-i64.parquet"),
+      sqlContext.range(1 << 4).select('id % 10 cast DecimalType(10, 2) as 'i64_dec))
+  }
+
+  // TODO Adds test case for reading dictionary encoded decimals written as `FIXED_LEN_BYTE_ARRAY`
+  // The Parquet writer version Spark 1.6 and prior versions use is `PARQUET_1_0`, which doesn't
+  // provide dictionary encoding support for `FIXED_LEN_BYTE_ARRAY`.  Should add a test here once
+  // we upgrade to `PARQUET_2_0`.
 }
 
 class JobCommitFailureParquetOutputCommitter(outputPath: Path, context: TaskAttemptContext)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetProtobufCompatibilitySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetProtobufCompatibilitySuite.scala
index b290429c2a..98333e58ca 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetProtobufCompatibilitySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetProtobufCompatibilitySuite.scala
@@ -17,23 +17,17 @@
 
 package org.apache.spark.sql.execution.datasources.parquet
 
-import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.Row
 import org.apache.spark.sql.test.SharedSQLContext
 
 class ParquetProtobufCompatibilitySuite extends ParquetCompatibilityTest with SharedSQLContext {
-
-  private def readParquetProtobufFile(name: String): DataFrame = {
-    val url = Thread.currentThread().getContextClassLoader.getResource(name)
-    sqlContext.read.parquet(url.toString)
-  }
-
   test("unannotated array of primitive type") {
-    checkAnswer(readParquetProtobufFile("old-repeated-int.parquet"), Row(Seq(1, 2, 3)))
+    checkAnswer(readResourceParquetFile("old-repeated-int.parquet"), Row(Seq(1, 2, 3)))
   }
 
   test("unannotated array of struct") {
     checkAnswer(
-      readParquetProtobufFile("old-repeated-message.parquet"),
+      readResourceParquetFile("old-repeated-message.parquet"),
       Row(
         Seq(
           Row("First inner", null, null),
@@ -41,14 +35,14 @@ class ParquetProtobufCompatibilitySuite extends ParquetCompatibilityTest with Sh
           Row(null, null, "Third inner"))))
 
     checkAnswer(
-      readParquetProtobufFile("proto-repeated-struct.parquet"),
+      readResourceParquetFile("proto-repeated-struct.parquet"),
       Row(
         Seq(
           Row("0 - 1", "0 - 2", "0 - 3"),
           Row("1 - 1", "1 - 2", "1 - 3"))))
 
     checkAnswer(
-      readParquetProtobufFile("proto-struct-with-array-many.parquet"),
+      readResourceParquetFile("proto-struct-with-array-many.parquet"),
       Seq(
         Row(
           Seq(
@@ -66,13 +60,13 @@ class ParquetProtobufCompatibilitySuite extends ParquetCompatibilityTest with Sh
 
   test("struct with unannotated array") {
     checkAnswer(
-      readParquetProtobufFile("proto-struct-with-array.parquet"),
+      readResourceParquetFile("proto-struct-with-array.parquet"),
       Row(10, 9, Seq.empty, null, Row(9), Seq(Row(9), Row(10))))
   }
 
   test("unannotated array of struct with unannotated array") {
     checkAnswer(
-      readParquetProtobufFile("nested-array-struct.parquet"),
+      readResourceParquetFile("nested-array-struct.parquet"),
       Seq(
         Row(2, Seq(Row(1, Seq(Row(3))))),
         Row(5, Seq(Row(4, Seq(Row(6))))),
@@ -81,7 +75,7 @@ class ParquetProtobufCompatibilitySuite extends ParquetCompatibilityTest with Sh
 
   test("unannotated array of string") {
     checkAnswer(
-      readParquetProtobufFile("proto-repeated-string.parquet"),
+      readResourceParquetFile("proto-repeated-string.parquet"),
       Seq(
         Row(Seq("hello", "world")),
         Row(Seq("good", "bye")),
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala
index 9840ad919e..8ffb01fc5b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala
@@ -139,4 +139,9 @@ private[sql] trait ParquetTest extends SQLTestUtils {
       withSQLConf(SQLConf.PARQUET_WRITE_LEGACY_FORMAT.key -> "true") { f }
     }
   }
+
+  protected def readResourceParquetFile(name: String): DataFrame = {
+    val url = Thread.currentThread().getContextClassLoader.getResource(name)
+    sqlContext.read.parquet(url.toString)
+  }
 }
-- 
GitLab