From f48273c13c9e9fea2d9bb6dda10fcaaaaa50c588 Mon Sep 17 00:00:00 2001 From: Michal Senkyr <mike.senkyr@gmail.com> Date: Mon, 12 Jun 2017 08:53:23 +0800 Subject: [PATCH] [SPARK-18891][SQL] Support for specific Java List subtypes ## What changes were proposed in this pull request? Add support for specific Java `List` subtypes in deserialization as well as a generic implicit encoder. All `List` subtypes are supported by using either the size-specifying constructor (one `int` parameter) or the default constructor. Interfaces/abstract classes use the following implementations: * `java.util.List`, `java.util.AbstractList` or `java.util.AbstractSequentialList` => `java.util.ArrayList` ## How was this patch tested? ```bash build/mvn -DskipTests clean package && dev/run-tests ``` Additionally in Spark shell: ``` scala> val jlist = new java.util.LinkedList[Int]; jlist.add(1) jlist: java.util.LinkedList[Int] = [1] res0: Boolean = true scala> Seq(jlist).toDS().map(_.element()).collect() res1: Array[Int] = Array(1) ``` Author: Michal Senkyr <mike.senkyr@gmail.com> Closes #18009 from michalsenkyr/dataset-java-lists. --- .../sql/catalyst/JavaTypeInference.scala | 15 ++--- .../expressions/objects/objects.scala | 19 +++++- .../apache/spark/sql/JavaDatasetSuite.java | 61 +++++++++++++++++++ 3 files changed, 83 insertions(+), 12 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala index 86a73a319e..7683ee7074 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala @@ -267,16 +267,11 @@ object JavaTypeInference { case c if listType.isAssignableFrom(typeToken) => val et = elementType(typeToken) - val array = - Invoke( - MapObjects( - p => deserializerFor(et, Some(p)), - getPath, - inferDataType(et)._1), - "array", - ObjectType(classOf[Array[Any]])) - - StaticInvoke(classOf[java.util.Arrays], ObjectType(c), "asList", array :: Nil) + MapObjects( + p => deserializerFor(et, Some(p)), + getPath, + inferDataType(et)._1, + customCollectionCls = Some(c)) case _ if mapType.isAssignableFrom(typeToken) => val (keyType, valueType) = mapKeyValueType(typeToken) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 79b7b9f3d0..5bb0febc94 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -22,6 +22,7 @@ import java.lang.reflect.Modifier import scala.collection.mutable.Builder import scala.language.existentials import scala.reflect.ClassTag +import scala.util.Try import org.apache.spark.{SparkConf, SparkEnv} import org.apache.spark.serializer._ @@ -597,8 +598,8 @@ case class MapObjects private( val (initCollection, addElement, getResult): (String, String => String, String) = customCollectionCls match { - case Some(cls) => - // collection + case Some(cls) if classOf[Seq[_]].isAssignableFrom(cls) => + // Scala sequence val getBuilder = s"${cls.getName}$$.MODULE$$.newBuilder()" val builder = ctx.freshName("collectionBuilder") ( @@ -609,6 +610,20 @@ case class MapObjects private( genValue => s"$builder.$$plus$$eq($genValue);", s"(${cls.getName}) $builder.result();" ) + case Some(cls) if classOf[java.util.List[_]].isAssignableFrom(cls) => + // Java list + val builder = ctx.freshName("collectionBuilder") + ( + if (cls == classOf[java.util.List[_]] || cls == classOf[java.util.AbstractList[_]] || + cls == classOf[java.util.AbstractSequentialList[_]]) { + s"${cls.getName} $builder = new java.util.ArrayList($dataLength);" + } else { + val param = Try(cls.getConstructor(Integer.TYPE)).map(_ => dataLength).getOrElse("") + s"${cls.getName} $builder = new ${cls.getName}($param);" + }, + genValue => s"$builder.add($genValue);", + s"$builder;" + ) case None => // array ( diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java index 3ba37addfc..4ca3b6406a 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java @@ -1399,4 +1399,65 @@ public class JavaDatasetSuite implements Serializable { ds1.map((MapFunction<NestedSmallBean, NestedSmallBean>) b -> b, encoder); Assert.assertEquals(beans, ds2.collectAsList()); } + + @Test + public void testSpecificLists() { + SpecificListsBean bean = new SpecificListsBean(); + ArrayList<Integer> arrayList = new ArrayList<>(); + arrayList.add(1); + bean.setArrayList(arrayList); + LinkedList<Integer> linkedList = new LinkedList<>(); + linkedList.add(1); + bean.setLinkedList(linkedList); + bean.setList(Collections.singletonList(1)); + List<SpecificListsBean> beans = Collections.singletonList(bean); + Dataset<SpecificListsBean> dataset = + spark.createDataset(beans, Encoders.bean(SpecificListsBean.class)); + Assert.assertEquals(beans, dataset.collectAsList()); + } + + public static class SpecificListsBean implements Serializable { + private ArrayList<Integer> arrayList; + private LinkedList<Integer> linkedList; + private List<Integer> list; + + public ArrayList<Integer> getArrayList() { + return arrayList; + } + + public void setArrayList(ArrayList<Integer> arrayList) { + this.arrayList = arrayList; + } + + public LinkedList<Integer> getLinkedList() { + return linkedList; + } + + public void setLinkedList(LinkedList<Integer> linkedList) { + this.linkedList = linkedList; + } + + public List<Integer> getList() { + return list; + } + + public void setList(List<Integer> list) { + this.list = list; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + SpecificListsBean that = (SpecificListsBean) o; + return Objects.equal(arrayList, that.arrayList) && + Objects.equal(linkedList, that.linkedList) && + Objects.equal(list, that.list); + } + + @Override + public int hashCode() { + return Objects.hashCode(arrayList, linkedList, list); + } + } } -- GitLab