From ccf536f903ef1f81fb3e1b6ce781d5e40d0ae3e0 Mon Sep 17 00:00:00 2001 From: Wenchen Fan <wenchen@databricks.com> Date: Wed, 21 Oct 2015 11:06:34 -0700 Subject: [PATCH] [SPARK-11216] [SQL] add encoder/decoder for external row Implement encode/decode for external row based on `ClassEncoder`. TODO: * code cleanup * ~~fix corner cases~~ * refactor the encoder interface * improve test for product codegen, to cover more corner cases. Author: Wenchen Fan <wenchen@databricks.com> Closes #9184 from cloud-fan/encoder. --- .../spark/sql/catalyst/ScalaReflection.scala | 6 +- .../sql/catalyst/encoders/ClassEncoder.scala | 75 ++++++ .../spark/sql/catalyst/encoders/Encoder.scala | 2 +- .../catalyst/encoders/ProductEncoder.scala | 46 +--- .../sql/catalyst/encoders/RowEncoder.scala | 234 ++++++++++++++++++ .../sql/catalyst/expressions/objects.scala | 46 +++- .../spark/sql/types/ArrayBasedMapData.scala | 4 + .../spark/sql/RandomDataGenerator.scala | 4 +- .../catalyst/encoders/RowEncoderSuite.scala | 96 +++++++ 9 files changed, 459 insertions(+), 54 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ClassEncoder.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 8edd6498e5..27c96f4122 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -19,11 +19,11 @@ package org.apache.spark.sql.catalyst import org.apache.spark.sql.catalyst.analysis.{UnresolvedExtractValue, UnresolvedAttribute} import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.unsafe.types.UTF8String -import org.apache.spark.util.Utils import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.LocalRelation import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.util.Utils /** * A default version of ScalaReflection that uses the runtime universe. @@ -142,7 +142,7 @@ trait ScalaReflection { } /** - * Returns an expression that can be used to construct an object of type `T` given a an input + * Returns an expression that can be used to construct an object of type `T` given an input * row with a compatible schema. Fields of the row will be extracted using UnresolvedAttributes * of the same name as the constructor arguments. Nested classes will have their fields accessed * using UnresolvedExtractValue. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ClassEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ClassEncoder.scala new file mode 100644 index 0000000000..f3a1063871 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ClassEncoder.scala @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.encoders + +import scala.reflect.ClassTag + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection, GenerateUnsafeProjection} +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project} +import org.apache.spark.sql.types.{ObjectType, StructType} + +/** + * A generic encoder for JVM objects. + * + * @param schema The schema after converting `T` to a Spark SQL row. + * @param extractExpressions A set of expressions, one for each top-level field that can be used to + * extract the values from a raw object. + * @param clsTag A classtag for `T`. + */ +case class ClassEncoder[T]( + schema: StructType, + extractExpressions: Seq[Expression], + constructExpression: Expression, + clsTag: ClassTag[T]) + extends Encoder[T] { + + private val extractProjection = GenerateUnsafeProjection.generate(extractExpressions) + private val inputRow = new GenericMutableRow(1) + + private lazy val constructProjection = GenerateSafeProjection.generate(constructExpression :: Nil) + private val dataType = ObjectType(clsTag.runtimeClass) + + override def toRow(t: T): InternalRow = { + if (t == null) { + null + } else { + inputRow(0) = t + extractProjection(inputRow) + } + } + + override def fromRow(row: InternalRow): T = { + if (row eq null) { + null.asInstanceOf[T] + } else { + constructProjection(row).get(0, dataType).asInstanceOf[T] + } + } + + override def bind(schema: Seq[Attribute]): ClassEncoder[T] = { + val plan = Project(Alias(constructExpression, "object")() :: Nil, LocalRelation(schema)) + val analyzedPlan = SimpleAnalyzer.execute(plan) + val resolvedExpression = analyzedPlan.expressions.head.children.head + val boundExpression = BindReferences.bindReference(resolvedExpression, schema) + + copy(constructExpression = boundExpression) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala index 3618247d5d..bdb1c0959d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala @@ -46,7 +46,7 @@ trait Encoder[T] { /** * Returns an object of type `T`, extracting the required values from the provided row. Note that - * you must bind` and encoder to a specific schema before you can call this function. + * you must bind the encoder to a specific schema before you can call this function. 
*/ def fromRow(row: InternalRow): T diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala index b0381880c3..4f7ce455ad 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala @@ -17,15 +17,11 @@ package org.apache.spark.sql.catalyst.encoders -import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection, GenerateUnsafeProjection} -import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project} - import scala.reflect.ClassTag import scala.reflect.runtime.universe.{typeTag, TypeTag} -import org.apache.spark.sql.catalyst.{ScalaReflection, InternalRow} +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.types.{ObjectType, StructType} /** @@ -44,44 +40,6 @@ object ProductEncoder { val constructExpression = ScalaReflection.constructorFor[T] new ClassEncoder[T](schema, extractExpressions, constructExpression, ClassTag[T](cls)) } -} - -/** - * A generic encoder for JVM objects. - * - * @param schema The schema after converting `T` to a Spark SQL row. - * @param extractExpressions A set of expressions, one for each top-level field that can be used to - * extract the values from a raw object. - * @param clsTag A classtag for `T`. - */ -case class ClassEncoder[T]( - schema: StructType, - extractExpressions: Seq[Expression], - constructExpression: Expression, - clsTag: ClassTag[T]) - extends Encoder[T] { - private val extractProjection = GenerateUnsafeProjection.generate(extractExpressions) - private val inputRow = new GenericMutableRow(1) - private lazy val constructProjection = GenerateSafeProjection.generate(constructExpression :: Nil) - private val dataType = ObjectType(clsTag.runtimeClass) - - override def toRow(t: T): InternalRow = { - inputRow(0) = t - extractProjection(inputRow) - } - - override def fromRow(row: InternalRow): T = { - constructProjection(row).get(0, dataType).asInstanceOf[T] - } - - override def bind(schema: Seq[Attribute]): ClassEncoder[T] = { - val plan = Project(Alias(constructExpression, "object")() :: Nil, LocalRelation(schema)) - val analyzedPlan = SimpleAnalyzer.execute(plan) - val resolvedExpression = analyzedPlan.expressions.head.children.head - val boundExpression = BindReferences.bindReference(resolvedExpression, schema) - - copy(constructExpression = boundExpression) - } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala new file mode 100644 index 0000000000..3e74aabd07 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala @@ -0,0 +1,234 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.encoders + +import scala.collection.Map +import scala.reflect.ClassTag + +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String + +object RowEncoder { + + def apply(schema: StructType): ClassEncoder[Row] = { + val cls = classOf[Row] + val inputObject = BoundReference(0, ObjectType(cls), nullable = true) + val extractExpressions = extractorsFor(inputObject, schema) + val constructExpression = constructorFor(schema) + new ClassEncoder[Row]( + schema, + extractExpressions.asInstanceOf[CreateStruct].children, + constructExpression, + ClassTag(cls)) + } + + private def extractorsFor( + inputObject: Expression, + inputType: DataType): Expression = inputType match { + case BooleanType | ByteType | ShortType | IntegerType | LongType | + FloatType | DoubleType | BinaryType => inputObject + + case TimestampType => + StaticInvoke( + DateTimeUtils, + TimestampType, + "fromJavaTimestamp", + inputObject :: Nil) + + case DateType => + StaticInvoke( + DateTimeUtils, + DateType, + "fromJavaDate", + inputObject :: Nil) + + case _: DecimalType => + StaticInvoke( + Decimal, + DecimalType.SYSTEM_DEFAULT, + "apply", + inputObject :: Nil) + + case StringType => + StaticInvoke( + classOf[UTF8String], + StringType, + "fromString", + inputObject :: Nil) + + case t @ ArrayType(et, _) => et match { + case BooleanType | ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType => + NewInstance( + classOf[GenericArrayData], + inputObject :: Nil, + dataType = t) + case _ => MapObjects(extractorsFor(_, et), inputObject, externalDataTypeFor(et)) + } + + case t @ MapType(kt, vt, valueNullable) => + val keys = + Invoke( + Invoke(inputObject, "keysIterator", ObjectType(classOf[scala.collection.Iterator[_]])), + "toSeq", + ObjectType(classOf[scala.collection.Seq[_]])) + val convertedKeys = extractorsFor(keys, ArrayType(kt, false)) + + val values = + Invoke( + Invoke(inputObject, "valuesIterator", ObjectType(classOf[scala.collection.Iterator[_]])), + "toSeq", + ObjectType(classOf[scala.collection.Seq[_]])) + val convertedValues = extractorsFor(values, ArrayType(vt, valueNullable)) + + NewInstance( + classOf[ArrayBasedMapData], + convertedKeys :: convertedValues :: Nil, + dataType = t) + + case StructType(fields) => + val convertedFields = fields.zipWithIndex.map { case (f, i) => + If( + Invoke(inputObject, "isNullAt", BooleanType, Literal(i) :: Nil), + Literal.create(null, f.dataType), + extractorsFor( + Invoke(inputObject, "get", externalDataTypeFor(f.dataType), Literal(i) :: Nil), + f.dataType)) + } + CreateStruct(convertedFields) + } + + private def externalDataTypeFor(dt: DataType): DataType = dt match { + case BooleanType | ByteType | ShortType | IntegerType | LongType | + FloatType | DoubleType | BinaryType => dt + case TimestampType => ObjectType(classOf[java.sql.Timestamp]) + case DateType => ObjectType(classOf[java.sql.Date]) + case _: DecimalType => 
ObjectType(classOf[java.math.BigDecimal]) + case StringType => ObjectType(classOf[java.lang.String]) + case _: ArrayType => ObjectType(classOf[scala.collection.Seq[_]]) + case _: MapType => ObjectType(classOf[scala.collection.Map[_, _]]) + case _: StructType => ObjectType(classOf[Row]) + } + + private def constructorFor(schema: StructType): Expression = { + val fields = schema.zipWithIndex.map { case (f, i) => + val field = BoundReference(i, f.dataType, f.nullable) + If( + IsNull(field), + Literal.create(null, externalDataTypeFor(f.dataType)), + constructorFor(BoundReference(i, f.dataType, f.nullable), f.dataType) + ) + } + CreateRow(fields) + } + + private def constructorFor(input: Expression, dataType: DataType): Expression = dataType match { + case BooleanType | ByteType | ShortType | IntegerType | LongType | + FloatType | DoubleType | BinaryType => input + + case TimestampType => + StaticInvoke( + DateTimeUtils, + ObjectType(classOf[java.sql.Timestamp]), + "toJavaTimestamp", + input :: Nil) + + case DateType => + StaticInvoke( + DateTimeUtils, + ObjectType(classOf[java.sql.Date]), + "toJavaDate", + input :: Nil) + + case _: DecimalType => + Invoke(input, "toJavaBigDecimal", ObjectType(classOf[java.math.BigDecimal])) + + case StringType => + Invoke(input, "toString", ObjectType(classOf[String])) + + case ArrayType(et, nullable) => + val arrayData = + Invoke( + MapObjects(constructorFor(_, et), input, et), + "array", + ObjectType(classOf[Array[_]])) + StaticInvoke( + scala.collection.mutable.WrappedArray, + ObjectType(classOf[Seq[_]]), + "make", + arrayData :: Nil) + + case MapType(kt, vt, valueNullable) => + val keyArrayType = ArrayType(kt, false) + val keyData = constructorFor(Invoke(input, "keyArray", keyArrayType), keyArrayType) + + val valueArrayType = ArrayType(vt, valueNullable) + val valueData = constructorFor(Invoke(input, "valueArray", valueArrayType), valueArrayType) + + StaticInvoke( + ArrayBasedMapData, + ObjectType(classOf[Map[_, _]]), + "toScalaMap", + keyData :: valueData :: Nil) + + case StructType(fields) => + val convertedFields = fields.zipWithIndex.map { case (f, i) => + If( + Invoke(input, "isNullAt", BooleanType, Literal(i) :: Nil), + Literal.create(null, externalDataTypeFor(f.dataType)), + constructorFor(getField(input, i, f.dataType), f.dataType)) + } + CreateRow(convertedFields) + } + + private def getField( + row: Expression, + ordinal: Int, + dataType: DataType): Expression = dataType match { + case BooleanType => + Invoke(row, "getBoolean", dataType, Literal(ordinal) :: Nil) + case ByteType => + Invoke(row, "getByte", dataType, Literal(ordinal) :: Nil) + case ShortType => + Invoke(row, "getShort", dataType, Literal(ordinal) :: Nil) + case IntegerType | DateType => + Invoke(row, "getInt", dataType, Literal(ordinal) :: Nil) + case LongType | TimestampType => + Invoke(row, "getLong", dataType, Literal(ordinal) :: Nil) + case FloatType => + Invoke(row, "getFloat", dataType, Literal(ordinal) :: Nil) + case DoubleType => + Invoke(row, "getDouble", dataType, Literal(ordinal) :: Nil) + case t: DecimalType => + Invoke(row, "getDecimal", dataType, Seq(ordinal, t.precision, t.scale).map(Literal(_))) + case StringType => + Invoke(row, "getUTF8String", dataType, Literal(ordinal) :: Nil) + case BinaryType => + Invoke(row, "getBinary", dataType, Literal(ordinal) :: Nil) + case CalendarIntervalType => + Invoke(row, "getInterval", dataType, Literal(ordinal) :: Nil) + case t: StructType => + Invoke(row, "getStruct", dataType, Literal(ordinal) :: Literal(t.size) :: Nil) + case 
_: ArrayType => + Invoke(row, "getArray", dataType, Literal(ordinal) :: Nil) + case _: MapType => + Invoke(row, "getMap", dataType, Literal(ordinal) :: Nil) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala index e8c1c93cf5..8fc00ad1bc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala @@ -17,12 +17,13 @@ package org.apache.spark.sql.catalyst.expressions +import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer import org.apache.spark.sql.catalyst.plans.logical.{Project, LocalRelation} import scala.language.existentials -import org.apache.spark.sql.catalyst.{ScalaReflection, InternalRow} +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext} import org.apache.spark.sql.types._ @@ -364,6 +365,10 @@ case class MapObjects( (".numElements()", (i: String) => s".getShort($i)", true) case ArrayType(BooleanType, _) => (".numElements()", (i: String) => s".getBoolean($i)", true) + case ArrayType(StringType, _) => + (".numElements()", (i: String) => s".getUTF8String($i)", false) + case ArrayType(_: MapType, _) => + (".numElements()", (i: String) => s".getMap($i)", false) } override def nullable: Boolean = true @@ -398,7 +403,7 @@ case class MapObjects( val convertedArray = ctx.freshName("convertedArray") val loopIndex = ctx.freshName("loopIndex") - val convertedType = ctx.javaType(boundFunction.dataType) + val convertedType = ctx.boxedType(boundFunction.dataType) // Because of the way Java defines nested arrays, we have to handle the syntax specially. 
// Specifically, we have to insert the [$dataLength] in between the type and any extra nested @@ -434,9 +439,13 @@ case class MapObjects( ($elementJavaType)${genInputData.value}${itemAccessor(loopIndex)}; $loopNullCheck - ${genFunction.code} + if ($loopIsNull) { + $convertedArray[$loopIndex] = null; + } else { + ${genFunction.code} + $convertedArray[$loopIndex] = ${genFunction.value}; + } - $convertedArray[$loopIndex] = ($convertedType)${genFunction.value}; $loopIndex += 1; } @@ -446,3 +455,32 @@ case class MapObjects( """ } } + +case class CreateRow(children: Seq[Expression]) extends Expression { + override def dataType: DataType = ObjectType(classOf[Row]) + + override def nullable: Boolean = false + + override def eval(input: InternalRow): Any = + throw new UnsupportedOperationException("Only code-generated evaluation is supported") + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val rowClass = classOf[GenericRow].getName + val values = ctx.freshName("values") + s""" + boolean ${ev.isNull} = false; + final Object[] $values = new Object[${children.size}]; + """ + + children.zipWithIndex.map { case (e, i) => + val eval = e.gen(ctx) + eval.code + s""" + if (${eval.isNull}) { + $values[$i] = null; + } else { + $values[$i] = ${eval.value}; + } + """ + }.mkString("\n") + + s"final ${classOf[Row].getName} ${ev.value} = new $rowClass($values);" + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayBasedMapData.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayBasedMapData.scala index 5f22e59d5f..e5ffe32217 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayBasedMapData.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayBasedMapData.scala @@ -66,4 +66,8 @@ object ArrayBasedMapData { def toScalaMap(keys: Array[Any], values: Array[Any]): Map[Any, Any] = { keys.zip(values).toMap } + + def toScalaMap(keys: Seq[Any], values: Seq[Any]): Map[Any, Any] = { + keys.zip(values).toMap + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala index e48395028e..7614f055e9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala @@ -148,7 +148,7 @@ object RandomDataGenerator { () => BigDecimal.apply( rand.nextLong() % math.pow(10, precision).toLong, scale, - new MathContext(precision))) + new MathContext(precision)).bigDecimal) case DoubleType => randomNumeric[Double]( rand, r => longBitsToDouble(r.nextLong()), Seq(Double.MinValue, Double.MinPositiveValue, Double.MaxValue, Double.PositiveInfinity, Double.NegativeInfinity, Double.NaN, 0.0)) @@ -166,7 +166,7 @@ object RandomDataGenerator { case NullType => Some(() => null) case ArrayType(elementType, containsNull) => { forType(elementType, nullable = containsNull, seed = Some(rand.nextLong())).map { - elementGenerator => () => Array.fill(rand.nextInt(MAX_ARR_SIZE))(elementGenerator()) + elementGenerator => () => Seq.fill(rand.nextInt(MAX_ARR_SIZE))(elementGenerator()) } } case MapType(keyType, valueType, valueContainsNull) => { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala new file mode 100644 index 0000000000..6041b62b74 --- /dev/null +++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.encoders + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.{RandomDataGenerator, Row} +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String + +class RowEncoderSuite extends SparkFunSuite { + + private val structOfString = new StructType().add("str", StringType) + private val arrayOfString = ArrayType(StringType) + private val mapOfString = MapType(StringType, StringType) + + encodeDecodeTest( + new StructType() + .add("boolean", BooleanType) + .add("byte", ByteType) + .add("short", ShortType) + .add("int", IntegerType) + .add("long", LongType) + .add("float", FloatType) + .add("double", DoubleType) + .add("decimal", DecimalType.SYSTEM_DEFAULT) + .add("string", StringType) + .add("binary", BinaryType) + .add("date", DateType) + .add("timestamp", TimestampType)) + + encodeDecodeTest( + new StructType() + .add("arrayOfString", arrayOfString) + .add("arrayOfArrayOfString", ArrayType(arrayOfString)) + .add("arrayOfArrayOfInt", ArrayType(ArrayType(IntegerType))) + .add("arrayOfMap", ArrayType(mapOfString)) + .add("arrayOfStruct", ArrayType(structOfString))) + + encodeDecodeTest( + new StructType() + .add("mapOfIntAndString", MapType(IntegerType, StringType)) + .add("mapOfStringAndArray", MapType(StringType, arrayOfString)) + .add("mapOfArrayAndInt", MapType(arrayOfString, IntegerType)) + .add("mapOfArray", MapType(arrayOfString, arrayOfString)) + .add("mapOfStringAndStruct", MapType(StringType, structOfString)) + .add("mapOfStructAndString", MapType(structOfString, StringType)) + .add("mapOfStruct", MapType(structOfString, structOfString))) + + encodeDecodeTest( + new StructType() + .add("structOfString", structOfString) + .add("structOfStructOfString", new StructType().add("struct", structOfString)) + .add("structOfArray", new StructType().add("array", arrayOfString)) + .add("structOfMap", new StructType().add("map", mapOfString)) + .add("structOfArrayAndMap", + new StructType().add("array", arrayOfString).add("map", mapOfString))) + + private def encodeDecodeTest(schema: StructType): Unit = { + test(s"encode/decode: ${schema.simpleString}") { + val encoder = RowEncoder(schema) + val inputGenerator = RandomDataGenerator.forType(schema).get + + var input: Row = null + try { + for (_ <- 1 to 5) { + input = inputGenerator.apply().asInstanceOf[Row] + val row = encoder.toRow(input) + val convertedBack = encoder.fromRow(row) + assert(input == convertedBack) + } + } catch { + case e: Exception => + fail( + s""" + |schema: ${schema.simpleString} + |input: ${input} + """.stripMargin, e) + } + } + } +} -- GitLab
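---

For orientation, here is a minimal sketch (not part of the patch) of the round trip the new `RowEncoder` is built for, mirroring what `RowEncoderSuite` exercises. The schema and values are invented for illustration, and it assumes a build that includes this patch:

```scala
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.types._

object RowEncoderRoundTrip {
  def main(args: Array[String]): Unit = {
    // One field per conversion strategy in extractorsFor: identity for
    // primitives, StaticInvoke(UTF8String.fromString) for strings, and a
    // recursive CreateStruct for nested rows.
    val schema = new StructType()
      .add("id", LongType)
      .add("name", StringType)
      .add("nested", new StructType().add("score", DoubleType))

    val encoder = RowEncoder(schema) // returns a ClassEncoder[Row]
    val input: Row = Row(1L, "a", Row(2.0))

    // toRow runs the generated unsafe projection: external Row -> InternalRow
    // (the String becomes a UTF8String, the nested Row an InternalRow).
    val internal: InternalRow = encoder.toRow(input)

    // fromRow runs the construct expression (CreateRow) back to an external
    // Row. No bind() is needed here because RowEncoder's constructorFor emits
    // BoundReferences directly, which is also why the test suite can call
    // fromRow immediately.
    val output: Row = encoder.fromRow(internal)
    assert(input == output)
  }
}
```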
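The `MapObjects` and `CreateRow` changes above exist so that null array elements and null fields can survive the conversion. Below is a hedged check of that behavior; the schema is invented for illustration, and whether every null combination already round-trips in this first cut is exactly the kind of corner case the commit's TODO list flags:

```scala
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.types._

object NullRoundTripSketch {
  def main(args: Array[String]): Unit = {
    // ArrayType(StringType) takes the MapObjects path on both sides; the patch
    // adds the getUTF8String element accessor and the per-element
    // `if (loopIsNull) { convertedArray[i] = null; }` guard in the generated
    // loop, while CreateRow null-checks each field expression it evaluates.
    val schema = new StructType()
      .add("tags", ArrayType(StringType, containsNull = true), nullable = true)

    val encoder = RowEncoder(schema)
    // A null element inside the array and a null top-level field exercise the
    // MapObjects guard and the IsNull branch in RowEncoder.constructorFor.
    for (input <- Seq(Row(Seq("a", null, "c")), Row(null))) {
      assert(encoder.fromRow(encoder.toRow(input)) == input)
    }
  }
}
```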
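Finally, the new `Seq` overload of `ArrayBasedMapData.toScalaMap` behaves exactly like the existing `Array` one — zip the key and value sequences and build an immutable map. It is needed because `constructorFor` converts the internal `keyArray`/`valueArray` into `Seq`s (via `WrappedArray.make`) before invoking `toScalaMap` through `StaticInvoke`:

```scala
import org.apache.spark.sql.types.ArrayBasedMapData

// Keys and values arrive as parallel sequences and are zipped into a Map.
val m = ArrayBasedMapData.toScalaMap(Seq[Any]("a", "b"), Seq[Any](1, 2))
assert(m == Map("a" -> 1, "b" -> 2))
```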