Commit 903bb8e8 authored by Michal Senkyr, committed by Wenchen Fan

[SPARK-16792][SQL] Dataset containing a Case Class with a List type causes a CompileException (converting sequence to list)

## What changes were proposed in this pull request?

Added a `to` call at the end of the code generated by `ScalaReflection.deserializerFor` whenever the requested type is not a supertype of `WrappedArray[_]`. The call uses `CanBuildFrom[_, _, _]` to convert the result into an arbitrary subtype of `Seq[_]`.
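
As a plain-Scala illustration of what the generated conversion does (a minimal sketch against the Scala 2.11-era collections API; the actual change builds the equivalent expression tree out of `Invoke`/`StaticInvoke` nodes):

```scala
import scala.collection.immutable.Queue
import scala.collection.mutable.WrappedArray

// Deserialization still produces a WrappedArray first...
val wrapped: WrappedArray[Int] = WrappedArray.make(Array(1, 2, 3))

// ...and `to` then rebuilds it as the requested Seq subtype, driven by the
// CanBuildFrom supplied by the target collection's companion object:
val asList: List[Int] = wrapped.to[List]
val asQueue: Queue[Int] = wrapped.to[Queue]
```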

Care was taken to preserve the original deserialization where possible, avoiding the overhead of conversion in cases where it is not needed.

`ScalaReflection.serializerFor` could already serialize any `Seq[_]`, so it was not altered.

`SQLImplicits` had to be altered, with new implicit encoders added, to permit encoding of other sequence types.
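
Types such as `List[T]` are both a `Seq` and a `Product`, so the new `Seq`-subtype encoders would be ambiguous with `newProductEncoder`; the latter therefore moves into the lower-priority parent trait `LowPrioritySQLImplicits`. A minimal sketch of the Scala resolution rule this relies on (all names here are made up for illustration):

```scala
// Enc stands in for Encoder in this sketch.
class Enc[T](val name: String)

trait LowPriority {
  implicit def productEnc[T <: Product]: Enc[T] = new Enc[T]("product")
}

object Encs extends LowPriority {
  implicit def seqEnc[T <: Seq[Int]]: Enc[T] = new Enc[T]("seq")
}

import Encs._
// List[Int] is both a Seq[Int] and a Product, but an implicit defined
// directly on Encs outranks one inherited from its parent trait:
println(implicitly[Enc[List[Int]]].name)  // prints "seq"
```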

Also fixes [SPARK-16815]: `Dataset[List[T]]` leads to `ArrayStoreException`.

## How was this patch tested?
```bash
./build/mvn -DskipTests clean package && ./dev/run-tests
```

Also verified by manually executing the following sets of commands in the Spark shell:
```scala
case class TestCC(key: Int, letters: List[String])

val ds1 = sc.makeRDD(Seq(
  List("D"),
  List("S", "H"),
  List("F", "H"),
  List("D", "L", "L")
)).map(x => (x.length, x)).toDF("key", "letters").as[TestCC]

val test1 = ds1.map(_.key)
test1.show
```

```scala
case class X(l: List[String])
spark.createDataset(Seq(List("A"))).map(X).show
```

```scala
spark.sqlContext.createDataset(sc.parallelize(List(1) :: Nil)).collect
```

After adding arbitrary sequence support, the following commands were also tested:

```scala
case class QueueClass(q: scala.collection.immutable.Queue[Int])

spark.createDataset(Seq(List(1,2,3))).map(x => QueueClass(scala.collection.immutable.Queue(x: _*))).map(_.q.dequeue).collect
```

Author: Michal Senkyr <mike.senkyr@gmail.com>

Closes #16240 from michalsenkyr/sql-caseclass-list-fix.
parent bcc510b0
@@ -312,12 +312,50 @@ object ScalaReflection extends ScalaReflection {
"array",
ObjectType(classOf[Array[Any]]))
StaticInvoke(
val wrappedArray = StaticInvoke(
scala.collection.mutable.WrappedArray.getClass,
ObjectType(classOf[Seq[_]]),
"make",
array :: Nil)
if (localTypeOf[scala.collection.mutable.WrappedArray[_]] <:< t.erasure) {
wrappedArray
} else {
// Convert to another type using `to`
val cls = mirror.runtimeClass(t.typeSymbol.asClass)
import scala.collection.generic.CanBuildFrom
import scala.reflect.ClassTag
// Some canBuildFrom methods take an implicit ClassTag parameter
val cbfParams = try {
cls.getDeclaredMethod("canBuildFrom", classOf[ClassTag[_]])
StaticInvoke(
ClassTag.getClass,
ObjectType(classOf[ClassTag[_]]),
"apply",
StaticInvoke(
cls,
ObjectType(classOf[Class[_]]),
"getClass"
) :: Nil
) :: Nil
} catch {
case _: NoSuchMethodException => Nil
}
Invoke(
wrappedArray,
"to",
ObjectType(cls),
StaticInvoke(
cls,
ObjectType(classOf[CanBuildFrom[_, _, _]]),
"canBuildFrom",
cbfParams
) :: Nil
)
}
case t if t <:< localTypeOf[Map[_, _]] =>
// TODO: add walked type path for map
val TypeRef(_, _, Seq(keyType, valueType)) = t
@@ -291,6 +291,37 @@ class ScalaReflectionSuite extends SparkFunSuite {
.cls.isAssignableFrom(classOf[org.apache.spark.sql.catalyst.util.GenericArrayData]))
}
test("SPARK 16792: Get correct deserializer for List[_]") {
val listDeserializer = deserializerFor[List[Int]]
assert(listDeserializer.dataType == ObjectType(classOf[List[_]]))
}
test("serialize and deserialize arbitrary sequence types") {
import scala.collection.immutable.Queue
val queueSerializer = serializerFor[Queue[Int]](BoundReference(
0, ObjectType(classOf[Queue[Int]]), nullable = false))
assert(queueSerializer.dataType.head.dataType ==
ArrayType(IntegerType, containsNull = false))
val queueDeserializer = deserializerFor[Queue[Int]]
assert(queueDeserializer.dataType == ObjectType(classOf[Queue[_]]))
import scala.collection.mutable.ArrayBuffer
val arrayBufferSerializer = serializerFor[ArrayBuffer[Int]](BoundReference(
0, ObjectType(classOf[ArrayBuffer[Int]]), nullable = false))
assert(arrayBufferSerializer.dataType.head.dataType ==
ArrayType(IntegerType, containsNull = false))
val arrayBufferDeserializer = deserializerFor[ArrayBuffer[Int]]
assert(arrayBufferDeserializer.dataType == ObjectType(classOf[ArrayBuffer[_]]))
// Check whether conversion is skipped when using WrappedArray[_] supertype
// (would otherwise needlessly add overhead)
import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
val seqDeserializer = deserializerFor[Seq[Int]]
assert(seqDeserializer.asInstanceOf[StaticInvoke].staticObject ==
scala.collection.mutable.WrappedArray.getClass)
assert(seqDeserializer.asInstanceOf[StaticInvoke].functionName == "make")
}
private val dataTypeForComplexData = dataTypeFor[ComplexData]
private val typeOfComplexData = typeOf[ComplexData]
@@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
* @since 1.6.0
*/
@InterfaceStability.Evolving
abstract class SQLImplicits {
abstract class SQLImplicits extends LowPrioritySQLImplicits {
protected def _sqlContext: SQLContext
@@ -45,9 +45,6 @@ abstract class SQLImplicits {
}
}
/** @since 1.6.0 */
implicit def newProductEncoder[T <: Product : TypeTag]: Encoder[T] = Encoders.product[T]
// Primitives
/** @since 1.6.0 */
@@ -112,33 +109,96 @@ abstract class SQLImplicits {
// Seqs
/** @since 1.6.1 */
implicit def newIntSeqEncoder: Encoder[Seq[Int]] = ExpressionEncoder()
/**
* @since 1.6.1
* @deprecated use [[newIntSequenceEncoder]]
*/
def newIntSeqEncoder: Encoder[Seq[Int]] = ExpressionEncoder()
/** @since 1.6.1 */
implicit def newLongSeqEncoder: Encoder[Seq[Long]] = ExpressionEncoder()
/**
* @since 1.6.1
* @deprecated use [[newLongSequenceEncoder]]
*/
def newLongSeqEncoder: Encoder[Seq[Long]] = ExpressionEncoder()
/** @since 1.6.1 */
implicit def newDoubleSeqEncoder: Encoder[Seq[Double]] = ExpressionEncoder()
/**
* @since 1.6.1
* @deprecated use [[newDoubleSequenceEncoder]]
*/
def newDoubleSeqEncoder: Encoder[Seq[Double]] = ExpressionEncoder()
/** @since 1.6.1 */
implicit def newFloatSeqEncoder: Encoder[Seq[Float]] = ExpressionEncoder()
/**
* @since 1.6.1
* @deprecated use [[newFloatSequenceEncoder]]
*/
def newFloatSeqEncoder: Encoder[Seq[Float]] = ExpressionEncoder()
/** @since 1.6.1 */
implicit def newByteSeqEncoder: Encoder[Seq[Byte]] = ExpressionEncoder()
/**
* @since 1.6.1
* @deprecated use [[newByteSequenceEncoder]]
*/
def newByteSeqEncoder: Encoder[Seq[Byte]] = ExpressionEncoder()
/** @since 1.6.1 */
implicit def newShortSeqEncoder: Encoder[Seq[Short]] = ExpressionEncoder()
/**
* @since 1.6.1
* @deprecated use [[newShortSequenceEncoder]]
*/
def newShortSeqEncoder: Encoder[Seq[Short]] = ExpressionEncoder()
/** @since 1.6.1 */
implicit def newBooleanSeqEncoder: Encoder[Seq[Boolean]] = ExpressionEncoder()
/**
* @since 1.6.1
* @deprecated use [[newBooleanSequenceEncoder]]
*/
def newBooleanSeqEncoder: Encoder[Seq[Boolean]] = ExpressionEncoder()
/** @since 1.6.1 */
implicit def newStringSeqEncoder: Encoder[Seq[String]] = ExpressionEncoder()
/**
* @since 1.6.1
* @deprecated use [[newStringSequenceEncoder]]
*/
def newStringSeqEncoder: Encoder[Seq[String]] = ExpressionEncoder()
/** @since 1.6.1 */
/**
* @since 1.6.1
* @deprecated use [[newProductSequenceEncoder]]
*/
implicit def newProductSeqEncoder[A <: Product : TypeTag]: Encoder[Seq[A]] = ExpressionEncoder()
/** @since 2.2.0 */
implicit def newIntSequenceEncoder[T <: Seq[Int] : TypeTag]: Encoder[T] =
ExpressionEncoder()
/** @since 2.2.0 */
implicit def newLongSequenceEncoder[T <: Seq[Long] : TypeTag]: Encoder[T] =
ExpressionEncoder()
/** @since 2.2.0 */
implicit def newDoubleSequenceEncoder[T <: Seq[Double] : TypeTag]: Encoder[T] =
ExpressionEncoder()
/** @since 2.2.0 */
implicit def newFloatSequenceEncoder[T <: Seq[Float] : TypeTag]: Encoder[T] =
ExpressionEncoder()
/** @since 2.2.0 */
implicit def newByteSequenceEncoder[T <: Seq[Byte] : TypeTag]: Encoder[T] =
ExpressionEncoder()
/** @since 2.2.0 */
implicit def newShortSequenceEncoder[T <: Seq[Short] : TypeTag]: Encoder[T] =
ExpressionEncoder()
/** @since 2.2.0 */
implicit def newBooleanSequenceEncoder[T <: Seq[Boolean] : TypeTag]: Encoder[T] =
ExpressionEncoder()
/** @since 2.2.0 */
implicit def newStringSequenceEncoder[T <: Seq[String] : TypeTag]: Encoder[T] =
ExpressionEncoder()
/** @since 2.2.0 */
implicit def newProductSequenceEncoder[T <: Seq[Product] : TypeTag]: Encoder[T] =
ExpressionEncoder()
// Arrays
/** @since 1.6.1 */
@@ -193,3 +253,16 @@ abstract class SQLImplicits {
implicit def symbolToColumn(s: Symbol): ColumnName = new ColumnName(s.name)
}
/**
* Lower priority implicit methods for converting Scala objects into [[Dataset]]s.
* Conflicting implicits are placed here to disambiguate resolution.
*
* Reasons for including specific implicits:
* newProductEncoder - to disambiguate for [[List]]s which are both [[Seq]] and [[Product]]
*/
trait LowPrioritySQLImplicits {
/** @since 1.6.0 */
implicit def newProductEncoder[T <: Product : TypeTag]: Encoder[T] = Encoders.product[T]
}
@@ -17,10 +17,21 @@
package org.apache.spark.sql
import scala.collection.immutable.Queue
import scala.collection.mutable.ArrayBuffer
import org.apache.spark.sql.test.SharedSQLContext
case class IntClass(value: Int)
case class SeqClass(s: Seq[Int])
case class ListClass(l: List[Int])
case class QueueClass(q: Queue[Int])
case class ComplexClass(seq: SeqClass, list: ListClass, queue: QueueClass)
package object packageobject {
case class PackageClass(value: Int)
}
@@ -130,6 +141,62 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext {
checkDataset(Seq(Array(Tuple1(1))).toDS(), Array(Tuple1(1)))
}
test("arbitrary sequences") {
checkDataset(Seq(Queue(1)).toDS(), Queue(1))
checkDataset(Seq(Queue(1.toLong)).toDS(), Queue(1.toLong))
checkDataset(Seq(Queue(1.toDouble)).toDS(), Queue(1.toDouble))
checkDataset(Seq(Queue(1.toFloat)).toDS(), Queue(1.toFloat))
checkDataset(Seq(Queue(1.toByte)).toDS(), Queue(1.toByte))
checkDataset(Seq(Queue(1.toShort)).toDS(), Queue(1.toShort))
checkDataset(Seq(Queue(true)).toDS(), Queue(true))
checkDataset(Seq(Queue("test")).toDS(), Queue("test"))
checkDataset(Seq(Queue(Tuple1(1))).toDS(), Queue(Tuple1(1)))
checkDataset(Seq(ArrayBuffer(1)).toDS(), ArrayBuffer(1))
checkDataset(Seq(ArrayBuffer(1.toLong)).toDS(), ArrayBuffer(1.toLong))
checkDataset(Seq(ArrayBuffer(1.toDouble)).toDS(), ArrayBuffer(1.toDouble))
checkDataset(Seq(ArrayBuffer(1.toFloat)).toDS(), ArrayBuffer(1.toFloat))
checkDataset(Seq(ArrayBuffer(1.toByte)).toDS(), ArrayBuffer(1.toByte))
checkDataset(Seq(ArrayBuffer(1.toShort)).toDS(), ArrayBuffer(1.toShort))
checkDataset(Seq(ArrayBuffer(true)).toDS(), ArrayBuffer(true))
checkDataset(Seq(ArrayBuffer("test")).toDS(), ArrayBuffer("test"))
checkDataset(Seq(ArrayBuffer(Tuple1(1))).toDS(), ArrayBuffer(Tuple1(1)))
}
test("sequence and product combinations") {
// Case classes
checkDataset(Seq(SeqClass(Seq(1))).toDS(), SeqClass(Seq(1)))
checkDataset(Seq(Seq(SeqClass(Seq(1)))).toDS(), Seq(SeqClass(Seq(1))))
checkDataset(Seq(List(SeqClass(Seq(1)))).toDS(), List(SeqClass(Seq(1))))
checkDataset(Seq(Queue(SeqClass(Seq(1)))).toDS(), Queue(SeqClass(Seq(1))))
checkDataset(Seq(ListClass(List(1))).toDS(), ListClass(List(1)))
checkDataset(Seq(Seq(ListClass(List(1)))).toDS(), Seq(ListClass(List(1))))
checkDataset(Seq(List(ListClass(List(1)))).toDS(), List(ListClass(List(1))))
checkDataset(Seq(Queue(ListClass(List(1)))).toDS(), Queue(ListClass(List(1))))
checkDataset(Seq(QueueClass(Queue(1))).toDS(), QueueClass(Queue(1)))
checkDataset(Seq(Seq(QueueClass(Queue(1)))).toDS(), Seq(QueueClass(Queue(1))))
checkDataset(Seq(List(QueueClass(Queue(1)))).toDS(), List(QueueClass(Queue(1))))
checkDataset(Seq(Queue(QueueClass(Queue(1)))).toDS(), Queue(QueueClass(Queue(1))))
val complex = ComplexClass(SeqClass(Seq(1)), ListClass(List(2)), QueueClass(Queue(3)))
checkDataset(Seq(complex).toDS(), complex)
checkDataset(Seq(Seq(complex)).toDS(), Seq(complex))
checkDataset(Seq(List(complex)).toDS(), List(complex))
checkDataset(Seq(Queue(complex)).toDS(), Queue(complex))
// Tuples
checkDataset(Seq(Seq(1) -> Seq(2)).toDS(), Seq(1) -> Seq(2))
checkDataset(Seq(List(1) -> Queue(2)).toDS(), List(1) -> Queue(2))
checkDataset(Seq(List(Seq("test1") -> List(Queue("test2")))).toDS(),
List(Seq("test1") -> List(Queue("test2"))))
// Complex
checkDataset(Seq(ListClass(List(1)) -> Queue("test" -> SeqClass(Seq(2)))).toDS(),
ListClass(List(1)) -> Queue("test" -> SeqClass(Seq(2))))
}
test("package objects") {
import packageobject._
checkDataset(Seq(PackageClass(1)).toDS(), PackageClass(1))
......