Commit 17591d90 authored by Kevin Yu, committed by Wenchen Fan


[SPARK-11827][SQL] Adding java.math.BigInteger support in Java type inference for POJOs and Java collections

Hello: Can you help check this PR? I am adding support for java.math.BigInteger in the Java bean code path. I saw that Spark internally converts BigInteger to BigDecimal in ColumnType.scala and CatalystRowConverter.scala, so I take the same approach here and convert the BigInteger to a BigDecimal.
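
As a rough illustration of the resulting behavior (a minimal sketch against the Decimal API as extended in this patch; the literal value is arbitrary):

    import org.apache.spark.sql.types.Decimal

    // A java.math.BigInteger is wrapped as a Decimal with precision 38 and scale 0,
    // mirroring the existing java.math.BigDecimal path.
    val d = Decimal(new java.math.BigInteger("23134123123"))
    assert(d.precision == 38 && d.scale == 0)
    assert(d.toJavaBigInteger == new java.math.BigInteger("23134123123"))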

Author: Kevin Yu <qyu@us.ibm.com>

Closes #10125 from kevinyu98/working_on_spark-11827.
parent d5c47f8f
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst
import java.lang.{Iterable => JavaIterable}
import java.math.{BigDecimal => JavaBigDecimal}
import java.math.{BigInteger => JavaBigInteger}
import java.sql.{Date, Timestamp}
import java.util.{Map => JavaMap}
import javax.annotation.Nullable
@@ -326,6 +327,7 @@ object CatalystTypeConverters {
val decimal = scalaValue match {
case d: BigDecimal => Decimal(d)
case d: JavaBigDecimal => Decimal(d)
case d: JavaBigInteger => Decimal(d)
case d: Decimal => d
}
if (decimal.changePrecision(dataType.precision, dataType.scale)) {
@@ -89,6 +89,7 @@ object JavaTypeInference {
case c: Class[_] if c == classOf[java.lang.Boolean] => (BooleanType, true)
case c: Class[_] if c == classOf[java.math.BigDecimal] => (DecimalType.SYSTEM_DEFAULT, true)
case c: Class[_] if c == classOf[java.math.BigInteger] => (DecimalType.BigIntDecimal, true)
case c: Class[_] if c == classOf[java.sql.Date] => (DateType, true)
case c: Class[_] if c == classOf[java.sql.Timestamp] => (TimestampType, true)
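
For context, a hedged sketch of the user-visible effect on the Java bean path (the bean class, field name, and SparkSession setup below are illustrative, not taken from this patch):

    import scala.beans.BeanProperty
    import org.apache.spark.sql.SparkSession

    // Illustrative bean with a BigInteger property; a plain Java POJO behaves the same way.
    class AmountBean(@BeanProperty var amount: java.math.BigInteger) {
      def this() = this(null)
    }

    val spark = SparkSession.builder().master("local").appName("SPARK-11827 sketch").getOrCreate()
    val beans = java.util.Arrays.asList(new AmountBean(new java.math.BigInteger("1234567")))
    val df = spark.createDataFrame(beans, classOf[AmountBean])
    df.printSchema()  // the amount column should now be inferred as decimal(38,0)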
@@ -259,6 +259,12 @@ object ScalaReflection extends ScalaReflection {
case t if t <:< localTypeOf[BigDecimal] =>
Invoke(getPath, "toBigDecimal", ObjectType(classOf[BigDecimal]))
case t if t <:< localTypeOf[java.math.BigInteger] =>
Invoke(getPath, "toJavaBigInteger", ObjectType(classOf[java.math.BigInteger]))
case t if t <:< localTypeOf[scala.math.BigInt] =>
Invoke(getPath, "toScalaBigInt", ObjectType(classOf[scala.math.BigInt]))
case t if t <:< localTypeOf[Array[_]] =>
val TypeRef(_, _, Seq(elementType)) = t
@@ -592,6 +598,20 @@ object ScalaReflection extends ScalaReflection {
"apply",
inputObject :: Nil)
case t if t <:< localTypeOf[java.math.BigInteger] =>
StaticInvoke(
Decimal.getClass,
DecimalType.BigIntDecimal,
"apply",
inputObject :: Nil)
case t if t <:< localTypeOf[scala.math.BigInt] =>
StaticInvoke(
Decimal.getClass,
DecimalType.BigIntDecimal,
"apply",
inputObject :: Nil)
case t if t <:< localTypeOf[java.lang.Integer] =>
Invoke(inputObject, "intValue", IntegerType)
case t if t <:< localTypeOf[java.lang.Long] =>
@@ -736,6 +756,10 @@ object ScalaReflection extends ScalaReflection {
case t if t <:< localTypeOf[BigDecimal] => Schema(DecimalType.SYSTEM_DEFAULT, nullable = true)
case t if t <:< localTypeOf[java.math.BigDecimal] =>
Schema(DecimalType.SYSTEM_DEFAULT, nullable = true)
case t if t <:< localTypeOf[java.math.BigInteger] =>
Schema(DecimalType.BigIntDecimal, nullable = true)
case t if t <:< localTypeOf[scala.math.BigInt] =>
Schema(DecimalType.BigIntDecimal, nullable = true)
case t if t <:< localTypeOf[Decimal] => Schema(DecimalType.SYSTEM_DEFAULT, nullable = true)
case t if t <:< localTypeOf[java.lang.Integer] => Schema(IntegerType, nullable = true)
case t if t <:< localTypeOf[java.lang.Long] => Schema(LongType, nullable = true)
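
A hedged sketch of what the ScalaReflection changes enable for case classes (assumes an active SparkSession named spark; the case class and field names are illustrative):

    case class BigIntData(javaBigInt: java.math.BigInteger, scalaBigInt: scala.math.BigInt)

    import spark.implicits._
    val ds = Seq(BigIntData(new java.math.BigInteger("1"), scala.math.BigInt(1))).toDS()
    ds.printSchema()
    // root
    //  |-- javaBigInt: decimal(38,0) (nullable = true)
    //  |-- scalaBigInt: decimal(38,0) (nullable = true)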
@@ -17,7 +17,7 @@
package org.apache.spark.sql.types
import java.math.{MathContext, RoundingMode}
import java.math.{BigInteger, MathContext, RoundingMode}
import org.apache.spark.annotation.DeveloperApi
@@ -128,6 +128,23 @@ final class Decimal extends Ordered[Decimal] with Serializable {
this
}
/**
* Set this Decimal to the given BigInteger value. Will have precision 38 and scale 0.
*/
def set(bigintval: BigInteger): Decimal = {
try {
this.decimalVal = null
this.longVal = bigintval.longValueExact()
this._precision = DecimalType.MAX_PRECISION
this._scale = 0
this
}
catch {
case e: ArithmeticException =>
throw new IllegalArgumentException(s"BigInteger ${bigintval} too large for decimal")
}
}
/**
* Set this Decimal to the given Decimal value.
*/
@@ -155,6 +172,10 @@ final class Decimal extends Ordered[Decimal] with Serializable {
}
}
def toScalaBigInt: BigInt = BigInt(toLong)
def toJavaBigInteger: java.math.BigInteger = java.math.BigInteger.valueOf(toLong)
def toUnscaledLong: Long = {
if (decimalVal.ne(null)) {
decimalVal.underlying().unscaledValue().longValue()
@@ -371,6 +392,10 @@ object Decimal {
def apply(value: java.math.BigDecimal): Decimal = new Decimal().set(value)
def apply(value: java.math.BigInteger): Decimal = new Decimal().set(value)
def apply(value: scala.math.BigInt): Decimal = new Decimal().set(value.bigInteger)
def apply(value: BigDecimal, precision: Int, scale: Int): Decimal =
new Decimal().set(value, precision, scale)
@@ -387,6 +412,8 @@ object Decimal {
value match {
case j: java.math.BigDecimal => apply(j)
case d: BigDecimal => apply(d)
case k: scala.math.BigInt => apply(k)
case l: java.math.BigInteger => apply(l)
case d: Decimal => d
}
}
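
One observation on the Decimal changes above (a small sketch, not part of the patch): set(BigInteger) stores the value with longValueExact, so a BigInteger outside the Long range is rejected with an IllegalArgumentException even though the declared precision is 38:

    import org.apache.spark.sql.types.Decimal

    val ok = Decimal(new java.math.BigInteger("9223372036854775807"))  // Long.MaxValue still fits
    try {
      Decimal(new java.math.BigInteger("9223372036854775808"))        // Long.MaxValue + 1 does not
    } catch {
      case e: IllegalArgumentException => println(e.getMessage)       // "BigInteger ... too large for decimal"
    }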
@@ -117,6 +117,7 @@ object DecimalType extends AbstractDataType {
private[sql] val LongDecimal = DecimalType(20, 0)
private[sql] val FloatDecimal = DecimalType(14, 7)
private[sql] val DoubleDecimal = DecimalType(30, 15)
private[sql] val BigIntDecimal = DecimalType(38, 0)
private[sql] def forType(dataType: DataType): DecimalType = dataType match {
case ByteType => ByteDecimal
@@ -17,6 +17,7 @@
package org.apache.spark.sql.catalyst.encoders
import java.math.BigInteger
import java.sql.{Date, Timestamp}
import java.util.Arrays
@@ -109,7 +110,8 @@ class ExpressionEncoderSuite extends PlanTest with AnalysisTest {
encodeDecodeTest(BigDecimal("32131413.211321313"), "scala decimal")
encodeDecodeTest(new java.math.BigDecimal("231341.23123"), "java decimal")
encodeDecodeTest(BigInt("23134123123"), "scala biginteger")
encodeDecodeTest(new BigInteger("23134123123"), "java BigInteger")
encodeDecodeTest(Decimal("32131413.211321313"), "catalyst decimal")
encodeDecodeTest("hello", "string")
@@ -21,6 +21,8 @@ import java.io.Serializable;
import java.net.URISyntaxException;
import java.net.URL;
import java.util.*;
import java.math.BigInteger;
import java.math.BigDecimal;
import scala.collection.JavaConverters;
import scala.collection.Seq;
@@ -130,6 +132,7 @@ public class JavaDataFrameSuite {
private Integer[] b = { 0, 1 };
private Map<String, int[]> c = ImmutableMap.of("hello", new int[] { 1, 2 });
private List<String> d = Arrays.asList("floppy", "disk");
private BigInteger e = new BigInteger("1234567");
public double getA() {
return a;
@@ -146,6 +149,8 @@ public class JavaDataFrameSuite {
public List<String> getD() {
return d;
}
public BigInteger getE() { return e; }
}
void validateDataFrameWithBeans(Bean bean, Dataset<Row> df) {
@@ -163,7 +168,9 @@ public class JavaDataFrameSuite {
Assert.assertEquals(
new StructField("d", new ArrayType(DataTypes.StringType, true), true, Metadata.empty()),
schema.apply("d"));
Row first = df.select("a", "b", "c", "d").first();
Assert.assertEquals(new StructField("e", DataTypes.createDecimalType(38,0), true, Metadata.empty()),
schema.apply("e"));
Row first = df.select("a", "b", "c", "d", "e").first();
Assert.assertEquals(bean.getA(), first.getDouble(0), 0.0);
// Now Java lists and maps are converted to Scala Seq's and Map's. Once we get a Seq below,
// verify that it has the expected length, and contains expected elements.
@@ -182,6 +189,8 @@ public class JavaDataFrameSuite {
for (int i = 0; i < d.length(); i++) {
Assert.assertEquals(bean.getD().get(i), d.apply(i));
}
// java.math.BigInteger is equivalent to Spark Decimal(38,0)
Assert.assertEquals(new BigDecimal(bean.getE()), first.getDecimal(4));
}
@Test
@@ -34,7 +34,9 @@ case class ReflectData(
decimalField: java.math.BigDecimal,
date: Date,
timestampField: Timestamp,
seqInt: Seq[Int])
seqInt: Seq[Int],
javaBigInt: java.math.BigInteger,
scalaBigInt: scala.math.BigInt)
case class NullReflectData(
intField: java.lang.Integer,
@@ -77,13 +79,15 @@ class ScalaReflectionRelationSuite extends SparkFunSuite with SharedSQLContext {
test("query case class RDD") {
val data = ReflectData("a", 1, 1L, 1.toFloat, 1.toDouble, 1.toShort, 1.toByte, true,
new java.math.BigDecimal(1), Date.valueOf("1970-01-01"), new Timestamp(12345), Seq(1, 2, 3))
new java.math.BigDecimal(1), Date.valueOf("1970-01-01"), new Timestamp(12345), Seq(1, 2, 3),
new java.math.BigInteger("1"), scala.math.BigInt(1))
Seq(data).toDF().createOrReplaceTempView("reflectData")
assert(sql("SELECT * FROM reflectData").collect().head ===
Row("a", 1, 1L, 1.toFloat, 1.toDouble, 1.toShort, 1.toByte, true,
new java.math.BigDecimal(1), Date.valueOf("1970-01-01"),
new Timestamp(12345), Seq(1, 2, 3)))
new Timestamp(12345), Seq(1, 2, 3), new java.math.BigDecimal(1),
new java.math.BigDecimal(1)))
}
test("query case class RDD with nulls") {