Skip to content
Snippets Groups Projects
Commit f48273c1 authored by Michal Senkyr, committed by Wenchen Fan
Browse files

[SPARK-18891][SQL] Support for specific Java List subtypes

## What changes were proposed in this pull request?

Add support for specific Java `List` subtypes in deserialization as well as a generic implicit encoder.

All concrete `List` subtypes are supported: an instance is created through the size-specifying constructor (a single `int` parameter) when the class declares one, and through the default (no-argument) constructor otherwise.

Interfaces/abstract classes use the following implementations:

* `java.util.List`, `java.util.AbstractList` or `java.util.AbstractSequentialList` => `java.util.ArrayList`

## How was this patch tested?

```bash
build/mvn -DskipTests clean package && dev/run-tests
```

Additionally in Spark shell:

```
scala> val jlist = new java.util.LinkedList[Int]; jlist.add(1)
jlist: java.util.LinkedList[Int] = [1]
res0: Boolean = true

scala> Seq(jlist).toDS().map(_.element()).collect()
res1: Array[Int] = Array(1)
```

Author: Michal Senkyr <mike.senkyr@gmail.com>

Closes #18009 from michalsenkyr/dataset-java-lists.
parent 0538f3b0
No related branches found
No related tags found
No related merge requests found
......@@ -267,16 +267,11 @@ object JavaTypeInference {
case c if listType.isAssignableFrom(typeToken) =>
val et = elementType(typeToken)
val array =
Invoke(
MapObjects(
p => deserializerFor(et, Some(p)),
getPath,
inferDataType(et)._1),
"array",
ObjectType(classOf[Array[Any]]))
StaticInvoke(classOf[java.util.Arrays], ObjectType(c), "asList", array :: Nil)
MapObjects(
p => deserializerFor(et, Some(p)),
getPath,
inferDataType(et)._1,
customCollectionCls = Some(c))
case _ if mapType.isAssignableFrom(typeToken) =>
val (keyType, valueType) = mapKeyValueType(typeToken)
......
......@@ -22,6 +22,7 @@ import java.lang.reflect.Modifier
import scala.collection.mutable.Builder
import scala.language.existentials
import scala.reflect.ClassTag
import scala.util.Try
import org.apache.spark.{SparkConf, SparkEnv}
import org.apache.spark.serializer._
......@@ -597,8 +598,8 @@ case class MapObjects private(
val (initCollection, addElement, getResult): (String, String => String, String) =
customCollectionCls match {
case Some(cls) =>
// collection
case Some(cls) if classOf[Seq[_]].isAssignableFrom(cls) =>
// Scala sequence
val getBuilder = s"${cls.getName}$$.MODULE$$.newBuilder()"
val builder = ctx.freshName("collectionBuilder")
(
......@@ -609,6 +610,20 @@ case class MapObjects private(
genValue => s"$builder.$$plus$$eq($genValue);",
s"(${cls.getName}) $builder.result();"
)
case Some(cls) if classOf[java.util.List[_]].isAssignableFrom(cls) =>
// Java list
val builder = ctx.freshName("collectionBuilder")
(
if (cls == classOf[java.util.List[_]] || cls == classOf[java.util.AbstractList[_]] ||
cls == classOf[java.util.AbstractSequentialList[_]]) {
s"${cls.getName} $builder = new java.util.ArrayList($dataLength);"
} else {
val param = Try(cls.getConstructor(Integer.TYPE)).map(_ => dataLength).getOrElse("")
s"${cls.getName} $builder = new ${cls.getName}($param);"
},
genValue => s"$builder.add($genValue);",
s"$builder;"
)
case None =>
// array
(
......
......@@ -1399,4 +1399,65 @@ public class JavaDatasetSuite implements Serializable {
ds1.map((MapFunction<NestedSmallBean, NestedSmallBean>) b -> b, encoder);
Assert.assertEquals(beans, ds2.collectAsList());
}
/**
 * Round-trips a {@link SpecificListsBean} through a bean-encoded Dataset and checks that
 * the collected result equals the input. The bean carries one field each of type
 * {@code ArrayList}, {@code LinkedList} and plain {@code List}.
 */
@Test
public void testSpecificLists() {
  // One single-element list per declared field type.
  ArrayList<Integer> arrayListValue = new ArrayList<>(Collections.singletonList(1));
  LinkedList<Integer> linkedListValue = new LinkedList<>(Collections.singletonList(1));
  List<Integer> listValue = Collections.singletonList(1);

  SpecificListsBean expectedBean = new SpecificListsBean();
  expectedBean.setArrayList(arrayListValue);
  expectedBean.setLinkedList(linkedListValue);
  expectedBean.setList(listValue);

  List<SpecificListsBean> expected = Collections.singletonList(expectedBean);

  // Encode with the bean encoder, then pull the rows back to the driver.
  List<SpecificListsBean> actual =
    spark.createDataset(expected, Encoders.bean(SpecificListsBean.class)).collectAsList();

  Assert.assertEquals(expected, actual);
}
/**
 * Java bean with three list-typed properties, each declared with a different static type
 * ({@code ArrayList}, {@code LinkedList} and the {@code List} interface).
 * Equality and hashing are value-based over all three properties.
 */
public static class SpecificListsBean implements Serializable {
  private ArrayList<Integer> arrayList;
  private LinkedList<Integer> linkedList;
  private List<Integer> list;

  // --- arrayList property ---

  public ArrayList<Integer> getArrayList() {
    return this.arrayList;
  }

  public void setArrayList(ArrayList<Integer> arrayList) {
    this.arrayList = arrayList;
  }

  // --- linkedList property ---

  public LinkedList<Integer> getLinkedList() {
    return this.linkedList;
  }

  public void setLinkedList(LinkedList<Integer> linkedList) {
    this.linkedList = linkedList;
  }

  // --- list property ---

  public List<Integer> getList() {
    return this.list;
  }

  public void setList(List<Integer> list) {
    this.list = list;
  }

  /**
   * Two beans are equal when they have exactly the same runtime class and all three
   * list properties compare equal (null-safe).
   */
  @Override
  public boolean equals(Object o) {
    if (this == o) {
      return true;
    }
    if (o == null || getClass() != o.getClass()) {
      return false;
    }
    SpecificListsBean other = (SpecificListsBean) o;
    // Compare field by field, bailing out on the first mismatch.
    if (!Objects.equal(arrayList, other.arrayList)) {
      return false;
    }
    if (!Objects.equal(linkedList, other.linkedList)) {
      return false;
    }
    return Objects.equal(list, other.list);
  }

  /** Hash over the same three properties that {@link #equals(Object)} compares. */
  @Override
  public int hashCode() {
    return Objects.hashCode(arrayList, linkedList, list);
  }
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment