Commit aa19c696 authored by Rene Treffer, committed by Cheng Lian

[SPARK-4176] [SQL] Supports decimal types with precision > 18 in Parquet

This PR is based on #6796 authored by rtreffer.

To support large decimal precisions (> 18), this PR does the following:

1. Making `CatalystSchemaConverter` support large decimal precision

   Decimal types with large precision are always converted to fixed-length byte arrays.

2. Making `CatalystRowConverter` support reading decimal values with large precision

   When the precision is > 18, `Decimal` values are constructed from an unscaled `BigInteger` rather than an unscaled `Long` (see the standalone sketch after this list).

3. Making `RowWriteSupport` support writing decimal values with large precision

   In this PR we always write decimals as fixed-length byte arrays, because the Parquet write path hasn't been refactored yet to conform to the Parquet format spec (see SPARK-6774 & SPARK-8848).
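
For reference, here is a minimal standalone sketch of the encoding behind items 2 and 3: the unscaled value of a decimal is stored as a big-endian two's-complement byte array, and on the read path it is decoded back into either an unscaled `Long` (small precisions) or a `BigInteger` / `BigDecimal` (large precisions). The object and method names below are made up for illustration and are not the actual Spark classes:

```scala
// Standalone sketch (hypothetical names, not Spark code) of decoding a decimal's
// fixed-length, big-endian, two's-complement representation.
import java.math.{BigDecimal, BigInteger}

object DecimalBinarySketch {
  // Decode a small unscaled value into a Long, sign-extending the top byte.
  def bytesToUnscaledLong(bytes: Array[Byte]): Long = {
    var unscaled = 0L
    var i = 0
    while (i < bytes.length) {
      unscaled = (unscaled << 8) | (bytes(i) & 0xff)
      i += 1
    }
    val bits = 8 * bytes.length
    (unscaled << (64 - bits)) >> (64 - bits)  // sign extension
  }

  // Large precisions go through BigInteger, which accepts arbitrary-width
  // two's-complement byte arrays directly.
  def bytesToBigDecimal(bytes: Array[Byte], scale: Int): BigDecimal =
    new BigDecimal(new BigInteger(bytes), scale)

  def main(args: Array[String]): Unit = {
    println(bytesToUnscaledLong(Array(0xff.toByte, 0x85.toByte)))          // -123
    println(bytesToBigDecimal(BigInteger.valueOf(-12345).toByteArray, 2))  // -123.45
  }
}
```

The `(unscaled << (64 - bits)) >> (64 - bits)` trick sign-extends the decoded value, so negative decimals survive the round trip.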

Two follow-up tasks should be done in future PRs:

- [ ] Writing decimals as `INT32`, `INT64` when possible while fixing SPARK-8848
- [ ] Adding compatibility tests as part of SPARK-5463

Author: Cheng Lian <lian@databricks.com>

Closes #7455 from liancheng/spark-4176 and squashes the following commits:

a543d10 [Cheng Lian] Fixes errors introduced while rebasing
9e31cdf [Cheng Lian] Supports decimals with precision > 18 for Parquet
parent 62283816

```diff
@@ -17,6 +17,7 @@
 package org.apache.spark.sql.parquet
 
+import java.math.{BigDecimal, BigInteger}
 import java.nio.ByteOrder
 
 import scala.collection.JavaConversions._
```
```diff
@@ -263,17 +264,23 @@ private[parquet] class CatalystRowConverter(
       val scale = decimalType.scale
       val bytes = value.getBytes
 
-      var unscaled = 0L
-      var i = 0
+      if (precision <= 8) {
+        // Constructs a `Decimal` with an unscaled `Long` value if possible.
+        var unscaled = 0L
+        var i = 0
 
-      while (i < bytes.length) {
-        unscaled = (unscaled << 8) | (bytes(i) & 0xff)
-        i += 1
-      }
+        while (i < bytes.length) {
+          unscaled = (unscaled << 8) | (bytes(i) & 0xff)
+          i += 1
+        }
 
-      val bits = 8 * bytes.length
-      unscaled = (unscaled << (64 - bits)) >> (64 - bits)
-      Decimal(unscaled, precision, scale)
+        val bits = 8 * bytes.length
+        unscaled = (unscaled << (64 - bits)) >> (64 - bits)
+        Decimal(unscaled, precision, scale)
+      } else {
+        // Otherwise, resorts to an unscaled `BigInteger` instead.
+        Decimal(new BigDecimal(new BigInteger(bytes), scale), precision, scale)
+      }
     }
   }
```
```diff
@@ -387,24 +387,18 @@ private[parquet] class CatalystSchemaConverter(
       // =====================================
 
       // Spark 1.4.x and prior versions only support decimals with a maximum precision of 18 and
-      // always store decimals in fixed-length byte arrays.
-      case DecimalType.Fixed(precision, scale)
-          if precision <= maxPrecisionForBytes(8) && !followParquetFormatSpec =>
+      // always store decimals in fixed-length byte arrays. To keep compatibility with these older
+      // versions, here we convert decimals with all precisions to `FIXED_LEN_BYTE_ARRAY` annotated
+      // by `DECIMAL`.
+      case DecimalType.Fixed(precision, scale) if !followParquetFormatSpec =>
         Types
           .primitive(FIXED_LEN_BYTE_ARRAY, repetition)
           .as(DECIMAL)
           .precision(precision)
           .scale(scale)
-          .length(minBytesForPrecision(precision))
+          .length(CatalystSchemaConverter.minBytesForPrecision(precision))
           .named(field.name)
 
-      case dec @ DecimalType() if !followParquetFormatSpec =>
-        throw new AnalysisException(
-          s"Data type $dec is not supported. " +
-            s"When ${SQLConf.PARQUET_FOLLOW_PARQUET_FORMAT_SPEC.key} is set to false," +
-            "decimal precision and scale must be specified, " +
-            "and precision must be less than or equal to 18.")
-
       // =====================================
       // Decimals (follow Parquet format spec)
       // =====================================
```
```diff
@@ -436,7 +430,7 @@ private[parquet] class CatalystSchemaConverter(
           .as(DECIMAL)
           .precision(precision)
           .scale(scale)
-          .length(minBytesForPrecision(precision))
+          .length(CatalystSchemaConverter.minBytesForPrecision(precision))
           .named(field.name)
 
       // ===================================================
```
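
For reference, under this legacy mode a `decimal(38, 18)` column ends up in the Parquet schema as a 16-byte `FIXED_LEN_BYTE_ARRAY` field annotated with `DECIMAL(38, 18)`, since 16 bytes is the smallest width whose signed range covers 38 decimal digits (illustrative example, not taken from the patch).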

```diff
@@ -548,15 +542,6 @@ private[parquet] class CatalystSchemaConverter(
           Math.pow(2, 8 * numBytes - 1) - 1)))  // max value stored in numBytes
       .asInstanceOf[Int]
   }
-
-  // Min byte counts needed to store decimals with various precisions
-  private val minBytesForPrecision: Array[Int] = Array.tabulate(38) { precision =>
-    var numBytes = 1
-    while (math.pow(2.0, 8 * numBytes - 1) < math.pow(10.0, precision)) {
-      numBytes += 1
-    }
-    numBytes
-  }
 }
```
```diff
@@ -580,4 +565,23 @@ private[parquet] object CatalystSchemaConverter {
       throw new AnalysisException(message)
     }
   }
+
+  private def computeMinBytesForPrecision(precision : Int) : Int = {
+    var numBytes = 1
+    while (math.pow(2.0, 8 * numBytes - 1) < math.pow(10.0, precision)) {
+      numBytes += 1
+    }
+    numBytes
+  }
+
+  private val MIN_BYTES_FOR_PRECISION = Array.tabulate[Int](39)(computeMinBytesForPrecision)
+
+  // Returns the minimum number of bytes needed to store a decimal with a given `precision`.
+  def minBytesForPrecision(precision : Int) : Int = {
+    if (precision < MIN_BYTES_FOR_PRECISION.length) {
+      MIN_BYTES_FOR_PRECISION(precision)
+    } else {
+      computeMinBytesForPrecision(precision)
+    }
+  }
 }
```
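
As a quick sanity check on `minBytesForPrecision`, the same computation can be reproduced outside Spark: under this formula, precisions up to 18 fit in 8 bytes (10^18 < 2^63), precision 19 needs 9 bytes, and the maximum precision of 38 needs 16 bytes. The object name below is made up for illustration:

```scala
// Standalone re-run of the minimum-width computation above (hypothetical helper, not Spark code).
object MinBytesForPrecisionSketch {
  def minBytes(precision: Int): Int = {
    // Smallest numBytes such that a signed (8 * numBytes - 1)-bit value can hold `precision` digits.
    var numBytes = 1
    while (math.pow(2.0, 8 * numBytes - 1) < math.pow(10.0, precision)) {
      numBytes += 1
    }
    numBytes
  }

  def main(args: Array[String]): Unit = {
    Seq(18, 19, 38).foreach(p => println(s"precision $p -> ${minBytes(p)} bytes"))
    // Expected output: precision 18 -> 8 bytes, precision 19 -> 9 bytes, precision 38 -> 16 bytes
  }
}
```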

```diff
@@ -17,6 +17,7 @@
 package org.apache.spark.sql.parquet
 
+import java.math.BigInteger
 import java.nio.{ByteBuffer, ByteOrder}
 import java.util.{HashMap => JHashMap}
```
```diff
@@ -114,11 +115,8 @@ private[parquet] class RowWriteSupport extends WriteSupport[InternalRow] with Logging
         Binary.fromByteArray(value.asInstanceOf[UTF8String].getBytes))
       case BinaryType => writer.addBinary(
         Binary.fromByteArray(value.asInstanceOf[Array[Byte]]))
-      case d: DecimalType =>
-        if (d.precision > 18) {
-          sys.error(s"Unsupported datatype $d, cannot write to consumer")
-        }
-        writeDecimal(value.asInstanceOf[Decimal], d.precision)
+      case DecimalType.Fixed(precision, _) =>
+        writeDecimal(value.asInstanceOf[Decimal], precision)
       case _ => sys.error(s"Do not know how to writer $schema to consumer")
     }
   }
```
```diff
@@ -199,20 +197,47 @@ private[parquet] class RowWriteSupport extends WriteSupport[InternalRow] with Logging
     writer.endGroup()
   }
 
-  // Scratch array used to write decimals as fixed-length binary
-  private[this] val scratchBytes = new Array[Byte](8)
+  // Scratch array used to write decimals as fixed-length byte array
+  private[this] var reusableDecimalBytes = new Array[Byte](16)
 
   private[parquet] def writeDecimal(decimal: Decimal, precision: Int): Unit = {
-    val numBytes = ParquetTypesConverter.BYTES_FOR_PRECISION(precision)
-    val unscaledLong = decimal.toUnscaledLong
-    var i = 0
-    var shift = 8 * (numBytes - 1)
-    while (i < numBytes) {
-      scratchBytes(i) = (unscaledLong >> shift).toByte
-      i += 1
-      shift -= 8
-    }
-    writer.addBinary(Binary.fromByteArray(scratchBytes, 0, numBytes))
+    val numBytes = CatalystSchemaConverter.minBytesForPrecision(precision)
+
+    def longToBinary(unscaled: Long): Binary = {
+      var i = 0
+      var shift = 8 * (numBytes - 1)
+      while (i < numBytes) {
+        reusableDecimalBytes(i) = (unscaled >> shift).toByte
+        i += 1
+        shift -= 8
+      }
+      Binary.fromByteArray(reusableDecimalBytes, 0, numBytes)
+    }
+
+    def bigIntegerToBinary(unscaled: BigInteger): Binary = {
+      unscaled.toByteArray match {
+        case bytes if bytes.length == numBytes =>
+          Binary.fromByteArray(bytes)
+
+        case bytes if bytes.length <= reusableDecimalBytes.length =>
+          val signedByte = (if (bytes.head < 0) -1 else 0).toByte
+          java.util.Arrays.fill(reusableDecimalBytes, 0, numBytes - bytes.length, signedByte)
+          System.arraycopy(bytes, 0, reusableDecimalBytes, numBytes - bytes.length, bytes.length)
+          Binary.fromByteArray(reusableDecimalBytes, 0, numBytes)
+
+        case bytes =>
+          reusableDecimalBytes = new Array[Byte](bytes.length)
+          bigIntegerToBinary(unscaled)
+      }
+    }
+
+    val binary = if (numBytes <= 8) {
+      longToBinary(decimal.toUnscaledLong)
+    } else {
+      bigIntegerToBinary(decimal.toJavaBigDecimal.unscaledValue())
+    }
+
+    writer.addBinary(binary)
   }
 
   // array used to write Timestamp as Int96 (fixed-length binary)
```
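
The sign-byte padding in `bigIntegerToBinary` above relies on the fact that left-padding a two's-complement byte array with copies of its sign byte (0x00 for non-negative values, 0xFF for negative ones) does not change the encoded value. A small standalone check (hypothetical names, not part of the patch):

```scala
import java.math.BigInteger

object SignExtensionSketch {
  // Left-pad a two's-complement, big-endian byte array with its sign byte up to `width` bytes.
  def padToWidth(bytes: Array[Byte], width: Int): Array[Byte] = {
    require(bytes.length <= width, "value already wider than the target width")
    val padded = new Array[Byte](width)
    val signByte = (if (bytes.head < 0) -1 else 0).toByte
    java.util.Arrays.fill(padded, 0, width - bytes.length, signByte)
    System.arraycopy(bytes, 0, padded, width - bytes.length, bytes.length)
    padded
  }

  def main(args: Array[String]): Unit = {
    val original = new BigInteger("-123456789012345678901234567890")
    val roundTripped = new BigInteger(padToWidth(original.toByteArray, 16))
    println(original == roundTripped)  // true: padding preserves the value
  }
}
```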
```diff
@@ -268,11 +293,8 @@ private[parquet] class MutableRowWriteSupport extends RowWriteSupport {
         writer.addBinary(Binary.fromByteArray(record.getUTF8String(index).getBytes))
       case BinaryType =>
         writer.addBinary(Binary.fromByteArray(record.getBinary(index)))
-      case d: DecimalType =>
-        if (d.precision > 18) {
-          sys.error(s"Unsupported datatype $d, cannot write to consumer")
-        }
-        writeDecimal(record.getDecimal(index), d.precision)
+      case DecimalType.Fixed(precision, _) =>
+        writeDecimal(record.getDecimal(index), precision)
       case _ => sys.error(s"Unsupported datatype $ctype, cannot write to consumer")
     }
   }
```

```diff
@@ -106,21 +106,13 @@ class ParquetIOSuite extends QueryTest with ParquetTest {
         // Parquet doesn't allow column names with spaces, have to add an alias here
         .select($"_1" cast decimal as "dec")
 
-    for ((precision, scale) <- Seq((5, 2), (1, 0), (1, 1), (18, 10), (18, 17))) {
+    for ((precision, scale) <- Seq((5, 2), (1, 0), (1, 1), (18, 10), (18, 17), (19, 0), (38, 37))) {
       withTempPath { dir =>
         val data = makeDecimalRDD(DecimalType(precision, scale))
         data.write.parquet(dir.getCanonicalPath)
         checkAnswer(sqlContext.read.parquet(dir.getCanonicalPath), data.collect().toSeq)
       }
     }
-
-    // Decimals with precision above 18 are not yet supported
-    intercept[Throwable] {
-      withTempPath { dir =>
-        makeDecimalRDD(DecimalType(19, 10)).write.parquet(dir.getCanonicalPath)
-        sqlContext.read.parquet(dir.getCanonicalPath).collect()
-      }
-    }
   }
 
   test("date type") {
```
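
With both the read and write paths updated, a round trip over a high-precision decimal column is expected to work end to end. A hedged usage sketch against the 1.5-era `SQLContext` API (e.g. in spark-shell, where `sqlContext` is predefined; the path and column name are made up):

```scala
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{DecimalType, StructField, StructType}

// Round-trip a decimal(38, 18) column through Parquet (hypothetical path and names).
val schema = StructType(Seq(StructField("dec", DecimalType(38, 18))))
val rows = sqlContext.sparkContext.parallelize(Seq(Row(BigDecimal("1.234567890123456789"))))
val df = sqlContext.createDataFrame(rows, schema)

df.write.parquet("/tmp/decimal-38-18")
sqlContext.read.parquet("/tmp/decimal-38-18").show()
```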