Skip to content
Snippets Groups Projects
Commit 29d92181 authored by Michael Armbrust's avatar Michael Armbrust
Browse files

[SPARK-13094][SQL] Add encoders for seq/array of primitives

Author: Michael Armbrust <michael@databricks.com>

Closes #11014 from marmbrus/seqEncoders.
parent 12a20c14
No related branches found
No related tags found
No related merge requests found
......@@ -39,6 +39,8 @@ abstract class SQLImplicits {
/** @since 1.6.0 */
implicit def newProductEncoder[T <: Product : TypeTag]: Encoder[T] = ExpressionEncoder()
// Primitives
/** @since 1.6.0 */
implicit def newIntEncoder: Encoder[Int] = ExpressionEncoder()
......@@ -56,13 +58,72 @@ abstract class SQLImplicits {
/** @since 1.6.0 */
implicit def newShortEncoder: Encoder[Short] = ExpressionEncoder()
/** @since 1.6.0 */
/** @since 1.6.0 */
implicit def newBooleanEncoder: Encoder[Boolean] = ExpressionEncoder()
/** @since 1.6.0 */
implicit def newStringEncoder: Encoder[String] = ExpressionEncoder()
// Seqs
/** @since 1.6.1 */
implicit def newIntSeqEncoder: Encoder[Seq[Int]] = ExpressionEncoder()
/** @since 1.6.1 */
implicit def newLongSeqEncoder: Encoder[Seq[Long]] = ExpressionEncoder()
/** @since 1.6.1 */
implicit def newDoubleSeqEncoder: Encoder[Seq[Double]] = ExpressionEncoder()
/** @since 1.6.1 */
implicit def newFloatSeqEncoder: Encoder[Seq[Float]] = ExpressionEncoder()
/** @since 1.6.1 */
implicit def newByteSeqEncoder: Encoder[Seq[Byte]] = ExpressionEncoder()
/** @since 1.6.1 */
implicit def newShortSeqEncoder: Encoder[Seq[Short]] = ExpressionEncoder()
/** @since 1.6.1 */
implicit def newBooleanSeqEncoder: Encoder[Seq[Boolean]] = ExpressionEncoder()
/** @since 1.6.1 */
implicit def newStringSeqEncoder: Encoder[Seq[String]] = ExpressionEncoder()
/** @since 1.6.1 */
implicit def newProductSeqEncoder[A <: Product : TypeTag]: Encoder[Seq[A]] = ExpressionEncoder()
// Arrays
/** @since 1.6.1 */
implicit def newIntArrayEncoder: Encoder[Array[Int]] = ExpressionEncoder()
/** @since 1.6.1 */
implicit def newLongArrayEncoder: Encoder[Array[Long]] = ExpressionEncoder()
/** @since 1.6.1 */
implicit def newDoubleArrayEncoder: Encoder[Array[Double]] = ExpressionEncoder()
/** @since 1.6.1 */
implicit def newFloatArrayEncoder: Encoder[Array[Float]] = ExpressionEncoder()
/** @since 1.6.1 */
implicit def newByteArrayEncoder: Encoder[Array[Byte]] = ExpressionEncoder()
/** @since 1.6.1 */
implicit def newShortArrayEncoder: Encoder[Array[Short]] = ExpressionEncoder()
/** @since 1.6.1 */
implicit def newBooleanArrayEncoder: Encoder[Array[Boolean]] = ExpressionEncoder()
/** @since 1.6.1 */
implicit def newStringArrayEncoder: Encoder[Array[String]] = ExpressionEncoder()
/** @since 1.6.1 */
implicit def newProductArrayEncoder[A <: Product : TypeTag]: Encoder[Array[A]] =
ExpressionEncoder()
/**
* Creates a [[Dataset]] from an RDD.
* @since 1.6.0
......
......@@ -105,4 +105,26 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext {
agged,
"1", "abc", "3", "xyz", "5", "hello")
}
test("Arrays and Lists") {
checkAnswer(Seq(Seq(1)).toDS(), Seq(1))
checkAnswer(Seq(Seq(1.toLong)).toDS(), Seq(1.toLong))
checkAnswer(Seq(Seq(1.toDouble)).toDS(), Seq(1.toDouble))
checkAnswer(Seq(Seq(1.toFloat)).toDS(), Seq(1.toFloat))
checkAnswer(Seq(Seq(1.toByte)).toDS(), Seq(1.toByte))
checkAnswer(Seq(Seq(1.toShort)).toDS(), Seq(1.toShort))
checkAnswer(Seq(Seq(true)).toDS(), Seq(true))
checkAnswer(Seq(Seq("test")).toDS(), Seq("test"))
checkAnswer(Seq(Seq(Tuple1(1))).toDS(), Seq(Tuple1(1)))
checkAnswer(Seq(Array(1)).toDS(), Array(1))
checkAnswer(Seq(Array(1.toLong)).toDS(), Array(1.toLong))
checkAnswer(Seq(Array(1.toDouble)).toDS(), Array(1.toDouble))
checkAnswer(Seq(Array(1.toFloat)).toDS(), Array(1.toFloat))
checkAnswer(Seq(Array(1.toByte)).toDS(), Array(1.toByte))
checkAnswer(Seq(Array(1.toShort)).toDS(), Array(1.toShort))
checkAnswer(Seq(Array(true)).toDS(), Array(true))
checkAnswer(Seq(Array("test")).toDS(), Array("test"))
checkAnswer(Seq(Array(Tuple1(1))).toDS(), Array(Tuple1(1)))
}
}
......@@ -95,7 +95,13 @@ abstract class QueryTest extends PlanTest {
""".stripMargin, e)
}
if (decoded != expectedAnswer.toSet) {
// Handle the case where the return type is an array
val isArray = decoded.headOption.map(_.getClass.isArray).getOrElse(false)
def normalEquality = decoded == expectedAnswer.toSet
def expectedAsSeq = expectedAnswer.map(_.asInstanceOf[Array[_]].toSeq).toSet
def decodedAsSeq = decoded.map(_.asInstanceOf[Array[_]].toSeq)
if (!((isArray && expectedAsSeq == decodedAsSeq) || normalEquality)) {
val expected = expectedAnswer.toSet.toSeq.map((a: Any) => a.toString).sorted
val actual = decoded.toSet.toSeq.map((a: Any) => a.toString).sorted
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment