Commit ccf536f9 authored by Wenchen Fan, committed by Michael Armbrust

[SPARK-11216] [SQL] add encoder/decoder for external row

Implement encode/decode for external row based on `ClassEncoder`.

TODO:
* code cleanup
* ~~fix corner cases~~
* refactor the encoder interface
* improve the test for product codegen to cover more corner cases.

Author: Wenchen Fan <wenchen@databricks.com>

Closes #9184 from cloud-fan/encoder.
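
For orientation, the round trip this patch enables looks roughly like the following sketch (schema and values invented for illustration; it mirrors the pattern used in RowEncoderSuite below):

import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.types._

// Build an encoder from an external Row schema, then convert to an InternalRow and back.
val schema = new StructType().add("name", StringType).add("age", IntegerType)
val encoder = RowEncoder(schema)
val internalRow = encoder.toRow(Row("Alice", 30)) // external Row -> internal representation
val roundTripped = encoder.fromRow(internalRow)   // internal representation -> external Row
assert(roundTripped == Row("Alice", 30))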
parent f62e3260
Showing 459 additions and 54 deletions
@@ -19,11 +19,11 @@ package org.apache.spark.sql.catalyst

 import org.apache.spark.sql.catalyst.analysis.{UnresolvedExtractValue, UnresolvedAttribute}
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
-import org.apache.spark.unsafe.types.UTF8String
-import org.apache.spark.util.Utils
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
 import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
+import org.apache.spark.util.Utils

 /**
  * A default version of ScalaReflection that uses the runtime universe.
@@ -142,7 +142,7 @@ trait ScalaReflection {
   }

   /**
-   * Returns an expression that can be used to construct an object of type `T` given a an input
+   * Returns an expression that can be used to construct an object of type `T` given an input
    * row with a compatible schema. Fields of the row will be extracted using UnresolvedAttributes
    * of the same name as the constructor arguments. Nested classes will have their fields accessed
    * using UnresolvedExtractValue.
...
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.catalyst.encoders

import scala.reflect.ClassTag

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection, GenerateUnsafeProjection}
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project}
import org.apache.spark.sql.types.{ObjectType, StructType}

/**
 * A generic encoder for JVM objects.
 *
 * @param schema The schema after converting `T` to a Spark SQL row.
 * @param extractExpressions A set of expressions, one for each top-level field that can be used to
 *                           extract the values from a raw object.
 * @param clsTag A classtag for `T`.
 */
case class ClassEncoder[T](
    schema: StructType,
    extractExpressions: Seq[Expression],
    constructExpression: Expression,
    clsTag: ClassTag[T])
  extends Encoder[T] {

  // Projects the raw object (placed in a one-slot mutable row) into an UnsafeRow.
  private val extractProjection = GenerateUnsafeProjection.generate(extractExpressions)
  private val inputRow = new GenericMutableRow(1)

  // Lazy because `constructExpression` may still need to be bound to a schema via `bind`.
  private lazy val constructProjection = GenerateSafeProjection.generate(constructExpression :: Nil)
  private val dataType = ObjectType(clsTag.runtimeClass)

  override def toRow(t: T): InternalRow = {
    if (t == null) {
      null
    } else {
      inputRow(0) = t
      extractProjection(inputRow)
    }
  }

  override def fromRow(row: InternalRow): T = {
    if (row eq null) {
      null.asInstanceOf[T]
    } else {
      constructProjection(row).get(0, dataType).asInstanceOf[T]
    }
  }

  override def bind(schema: Seq[Attribute]): ClassEncoder[T] = {
    // Wrap the construct expression in a dummy Project so the analyzer can resolve its
    // attributes against the given schema, then bind the resolved references to ordinals.
    val plan = Project(Alias(constructExpression, "object")() :: Nil, LocalRelation(schema))
    val analyzedPlan = SimpleAnalyzer.execute(plan)
    val resolvedExpression = analyzedPlan.expressions.head.children.head
    val boundExpression = BindReferences.bindReference(resolvedExpression, schema)
    copy(constructExpression = boundExpression)
  }
}
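
Note that `constructExpression` may contain unresolved attributes (as it does for encoders built from `ScalaReflection.constructorFor`), in which case the encoder must be bound before `fromRow` can be used. A hedged sketch, assuming a `ProductEncoder[Person]` factory as in ProductEncoder.scala below and `StructType.toAttributes` for the attribute list:

// Hypothetical case class; ProductEncoder derives schema and expressions via reflection.
case class Person(name: String, age: Int)

val encoder = ProductEncoder[Person]
val internalRow = encoder.toRow(Person("Alice", 30))
// Resolve and bind the construct expression against the encoder's own schema, then decode.
val bound = encoder.bind(encoder.schema.toAttributes)
val person = bound.fromRow(internalRow)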
@@ -46,7 +46,7 @@ trait Encoder[T] {

   /**
    * Returns an object of type `T`, extracting the required values from the provided row. Note that
-   * you must bind` and encoder to a specific schema before you can call this function.
+   * you must bind the encoder to a specific schema before you can call this function.
    */
   def fromRow(row: InternalRow): T
...
@@ -17,15 +17,11 @@

 package org.apache.spark.sql.catalyst.encoders

-import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection, GenerateUnsafeProjection}
-import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project}
 import scala.reflect.ClassTag
 import scala.reflect.runtime.universe.{typeTag, TypeTag}

-import org.apache.spark.sql.catalyst.{ScalaReflection, InternalRow}
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.ScalaReflection
 import org.apache.spark.sql.types.{ObjectType, StructType}

 /**
@@ -44,44 +40,6 @@ object ProductEncoder {
     val constructExpression = ScalaReflection.constructorFor[T]
     new ClassEncoder[T](schema, extractExpressions, constructExpression, ClassTag[T](cls))
   }
-}
-
-/**
- * A generic encoder for JVM objects.
- *
- * @param schema The schema after converting `T` to a Spark SQL row.
- * @param extractExpressions A set of expressions, one for each top-level field that can be used to
- *                           extract the values from a raw object.
- * @param clsTag A classtag for `T`.
- */
-case class ClassEncoder[T](
-    schema: StructType,
-    extractExpressions: Seq[Expression],
-    constructExpression: Expression,
-    clsTag: ClassTag[T])
-  extends Encoder[T] {
-
-  private val extractProjection = GenerateUnsafeProjection.generate(extractExpressions)
-  private val inputRow = new GenericMutableRow(1)
-
-  private lazy val constructProjection = GenerateSafeProjection.generate(constructExpression :: Nil)
-  private val dataType = ObjectType(clsTag.runtimeClass)
-
-  override def toRow(t: T): InternalRow = {
-    inputRow(0) = t
-    extractProjection(inputRow)
-  }
-
-  override def fromRow(row: InternalRow): T = {
-    constructProjection(row).get(0, dataType).asInstanceOf[T]
-  }
-
-  override def bind(schema: Seq[Attribute]): ClassEncoder[T] = {
-    val plan = Project(Alias(constructExpression, "object")() :: Nil, LocalRelation(schema))
-    val analyzedPlan = SimpleAnalyzer.execute(plan)
-    val resolvedExpression = analyzedPlan.expressions.head.children.head
-    val boundExpression = BindReferences.bindReference(resolvedExpression, schema)
-    copy(constructExpression = boundExpression)
-  }
 }
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.catalyst.encoders

import scala.collection.Map
import scala.reflect.ClassTag

import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

object RowEncoder {
  def apply(schema: StructType): ClassEncoder[Row] = {
    val cls = classOf[Row]
    val inputObject = BoundReference(0, ObjectType(cls), nullable = true)
    val extractExpressions = extractorsFor(inputObject, schema)
    val constructExpression = constructorFor(schema)
    new ClassEncoder[Row](
      schema,
      extractExpressions.asInstanceOf[CreateStruct].children,
      constructExpression,
      ClassTag(cls))
  }

  // Builds an expression that converts an external value of `inputType` into its internal form.
  private def extractorsFor(
      inputObject: Expression,
      inputType: DataType): Expression = inputType match {
    case BooleanType | ByteType | ShortType | IntegerType | LongType |
         FloatType | DoubleType | BinaryType => inputObject

    case TimestampType =>
      StaticInvoke(
        DateTimeUtils,
        TimestampType,
        "fromJavaTimestamp",
        inputObject :: Nil)

    case DateType =>
      StaticInvoke(
        DateTimeUtils,
        DateType,
        "fromJavaDate",
        inputObject :: Nil)

    case _: DecimalType =>
      StaticInvoke(
        Decimal,
        DecimalType.SYSTEM_DEFAULT,
        "apply",
        inputObject :: Nil)

    case StringType =>
      StaticInvoke(
        classOf[UTF8String],
        StringType,
        "fromString",
        inputObject :: Nil)

    case t @ ArrayType(et, _) => et match {
      case BooleanType | ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType =>
        // Primitive element types can be wrapped directly; everything else is converted per item.
        NewInstance(
          classOf[GenericArrayData],
          inputObject :: Nil,
          dataType = t)
      case _ => MapObjects(extractorsFor(_, et), inputObject, externalDataTypeFor(et))
    }

    case t @ MapType(kt, vt, valueNullable) =>
      // Convert keys and values separately as two arrays, then rebuild the map.
      val keys =
        Invoke(
          Invoke(inputObject, "keysIterator", ObjectType(classOf[scala.collection.Iterator[_]])),
          "toSeq",
          ObjectType(classOf[scala.collection.Seq[_]]))
      val convertedKeys = extractorsFor(keys, ArrayType(kt, false))

      val values =
        Invoke(
          Invoke(inputObject, "valuesIterator", ObjectType(classOf[scala.collection.Iterator[_]])),
          "toSeq",
          ObjectType(classOf[scala.collection.Seq[_]]))
      val convertedValues = extractorsFor(values, ArrayType(vt, valueNullable))

      NewInstance(
        classOf[ArrayBasedMapData],
        convertedKeys :: convertedValues :: Nil,
        dataType = t)

    case StructType(fields) =>
      val convertedFields = fields.zipWithIndex.map { case (f, i) =>
        If(
          Invoke(inputObject, "isNullAt", BooleanType, Literal(i) :: Nil),
          Literal.create(null, f.dataType),
          extractorsFor(
            Invoke(inputObject, "get", externalDataTypeFor(f.dataType), Literal(i) :: Nil),
            f.dataType))
      }
      CreateStruct(convertedFields)
  }

  // The external (JVM-facing) type corresponding to each Spark SQL data type.
  private def externalDataTypeFor(dt: DataType): DataType = dt match {
    case BooleanType | ByteType | ShortType | IntegerType | LongType |
         FloatType | DoubleType | BinaryType => dt
    case TimestampType => ObjectType(classOf[java.sql.Timestamp])
    case DateType => ObjectType(classOf[java.sql.Date])
    case _: DecimalType => ObjectType(classOf[java.math.BigDecimal])
    case StringType => ObjectType(classOf[java.lang.String])
    case _: ArrayType => ObjectType(classOf[scala.collection.Seq[_]])
    case _: MapType => ObjectType(classOf[scala.collection.Map[_, _]])
    case _: StructType => ObjectType(classOf[Row])
  }

  // Builds an expression that reassembles an external Row from a bound internal row.
  private def constructorFor(schema: StructType): Expression = {
    val fields = schema.zipWithIndex.map { case (f, i) =>
      val field = BoundReference(i, f.dataType, f.nullable)
      If(
        IsNull(field),
        Literal.create(null, externalDataTypeFor(f.dataType)),
        constructorFor(BoundReference(i, f.dataType, f.nullable), f.dataType)
      )
    }
    CreateRow(fields)
  }

  private def constructorFor(input: Expression, dataType: DataType): Expression = dataType match {
    case BooleanType | ByteType | ShortType | IntegerType | LongType |
         FloatType | DoubleType | BinaryType => input

    case TimestampType =>
      StaticInvoke(
        DateTimeUtils,
        ObjectType(classOf[java.sql.Timestamp]),
        "toJavaTimestamp",
        input :: Nil)

    case DateType =>
      StaticInvoke(
        DateTimeUtils,
        ObjectType(classOf[java.sql.Date]),
        "toJavaDate",
        input :: Nil)

    case _: DecimalType =>
      Invoke(input, "toJavaBigDecimal", ObjectType(classOf[java.math.BigDecimal]))

    case StringType =>
      Invoke(input, "toString", ObjectType(classOf[String]))

    case ArrayType(et, nullable) =>
      val arrayData =
        Invoke(
          MapObjects(constructorFor(_, et), input, et),
          "array",
          ObjectType(classOf[Array[_]]))
      StaticInvoke(
        scala.collection.mutable.WrappedArray,
        ObjectType(classOf[Seq[_]]),
        "make",
        arrayData :: Nil)

    case MapType(kt, vt, valueNullable) =>
      val keyArrayType = ArrayType(kt, false)
      val keyData = constructorFor(Invoke(input, "keyArray", keyArrayType), keyArrayType)

      val valueArrayType = ArrayType(vt, valueNullable)
      val valueData = constructorFor(Invoke(input, "valueArray", valueArrayType), valueArrayType)

      StaticInvoke(
        ArrayBasedMapData,
        ObjectType(classOf[Map[_, _]]),
        "toScalaMap",
        keyData :: valueData :: Nil)

    case StructType(fields) =>
      val convertedFields = fields.zipWithIndex.map { case (f, i) =>
        If(
          Invoke(input, "isNullAt", BooleanType, Literal(i) :: Nil),
          Literal.create(null, externalDataTypeFor(f.dataType)),
          constructorFor(getField(input, i, f.dataType), f.dataType))
      }
      CreateRow(convertedFields)
  }

  // Selects the ordinal-based getter matching the internal representation of `dataType`.
  private def getField(
      row: Expression,
      ordinal: Int,
      dataType: DataType): Expression = dataType match {
    case BooleanType =>
      Invoke(row, "getBoolean", dataType, Literal(ordinal) :: Nil)
    case ByteType =>
      Invoke(row, "getByte", dataType, Literal(ordinal) :: Nil)
    case ShortType =>
      Invoke(row, "getShort", dataType, Literal(ordinal) :: Nil)
    case IntegerType | DateType =>
      Invoke(row, "getInt", dataType, Literal(ordinal) :: Nil)
    case LongType | TimestampType =>
      Invoke(row, "getLong", dataType, Literal(ordinal) :: Nil)
    case FloatType =>
      Invoke(row, "getFloat", dataType, Literal(ordinal) :: Nil)
    case DoubleType =>
      Invoke(row, "getDouble", dataType, Literal(ordinal) :: Nil)
    case t: DecimalType =>
      Invoke(row, "getDecimal", dataType, Seq(ordinal, t.precision, t.scale).map(Literal(_)))
    case StringType =>
      Invoke(row, "getUTF8String", dataType, Literal(ordinal) :: Nil)
    case BinaryType =>
      Invoke(row, "getBinary", dataType, Literal(ordinal) :: Nil)
    case CalendarIntervalType =>
      Invoke(row, "getInterval", dataType, Literal(ordinal) :: Nil)
    case t: StructType =>
      Invoke(row, "getStruct", dataType, Literal(ordinal) :: Literal(t.size) :: Nil)
    case _: ArrayType =>
      Invoke(row, "getArray", dataType, Literal(ordinal) :: Nil)
    case _: MapType =>
      Invoke(row, "getMap", dataType, Literal(ordinal) :: Nil)
  }
}
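
Tracing one type through both directions: a TimestampType column is extracted via DateTimeUtils.fromJavaTimestamp into a Long of microseconds since the epoch, and reconstructed via toJavaTimestamp. A small illustrative sketch (values invented):

val schema = new StructType().add("ts", TimestampType)
val encoder = RowEncoder(schema)
val internal = encoder.toRow(Row(java.sql.Timestamp.valueOf("2015-10-23 12:00:00")))
// internal.getLong(0) now holds the timestamp as microseconds since the epoch.
val back = encoder.fromRow(internal).getAs[java.sql.Timestamp](0)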
@@ -17,12 +17,13 @@

 package org.apache.spark.sql.catalyst.expressions

+import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer
 import org.apache.spark.sql.catalyst.plans.logical.{Project, LocalRelation}

 import scala.language.existentials

-import org.apache.spark.sql.catalyst.{ScalaReflection, InternalRow}
+import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext}
 import org.apache.spark.sql.types._
@@ -364,6 +365,10 @@ case class MapObjects(
         (".numElements()", (i: String) => s".getShort($i)", true)
       case ArrayType(BooleanType, _) =>
         (".numElements()", (i: String) => s".getBoolean($i)", true)
+      case ArrayType(StringType, _) =>
+        (".numElements()", (i: String) => s".getUTF8String($i)", false)
+      case ArrayType(_: MapType, _) =>
+        (".numElements()", (i: String) => s".getMap($i)", false)
     }

   override def nullable: Boolean = true
@@ -398,7 +403,7 @@ case class MapObjects(
     val convertedArray = ctx.freshName("convertedArray")
     val loopIndex = ctx.freshName("loopIndex")
-    val convertedType = ctx.javaType(boundFunction.dataType)
+    val convertedType = ctx.boxedType(boundFunction.dataType)

     // Because of the way Java defines nested arrays, we have to handle the syntax specially.
     // Specifically, we have to insert the [$dataLength] in between the type and any extra nested
@@ -434,9 +439,13 @@ case class MapObjects(
           ($elementJavaType)${genInputData.value}${itemAccessor(loopIndex)};
         $loopNullCheck

-        ${genFunction.code}
-        $convertedArray[$loopIndex] = ($convertedType)${genFunction.value};
+        if ($loopIsNull) {
+          $convertedArray[$loopIndex] = null;
+        } else {
+          ${genFunction.code}
+          $convertedArray[$loopIndex] = ${genFunction.value};
+        }

         $loopIndex += 1;
       }
@@ -446,3 +455,32 @@ case class MapObjects(
     """
   }
 }
+
+case class CreateRow(children: Seq[Expression]) extends Expression {
+  override def dataType: DataType = ObjectType(classOf[Row])
+
+  override def nullable: Boolean = false
+
+  override def eval(input: InternalRow): Any =
+    throw new UnsupportedOperationException("Only code-generated evaluation is supported")
+
+  override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+    val rowClass = classOf[GenericRow].getName
+    val values = ctx.freshName("values")
+    s"""
+      boolean ${ev.isNull} = false;
+      final Object[] $values = new Object[${children.size}];
+    """ +
+      children.zipWithIndex.map { case (e, i) =>
+        val eval = e.gen(ctx)
+        eval.code + s"""
+          if (${eval.isNull}) {
+            $values[$i] = null;
+          } else {
+            $values[$i] = ${eval.value};
+          }
+        """
+      }.mkString("\n") +
+      s"final ${classOf[Row].getName} ${ev.value} = new $rowClass($values);"
+  }
+}
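
Since eval always throws, CreateRow is only usable through generated projections; ClassEncoder.fromRow evaluates its construct expression in exactly this way. A minimal sketch of direct use (assuming GenerateSafeProjection and the other catalyst helpers are imported):

// Wrap CreateRow in a generated safe projection, then read the external Row back by ordinal.
val expr = CreateRow(BoundReference(0, IntegerType, nullable = true) :: Nil)
val projection = GenerateSafeProjection.generate(expr :: Nil)
val result = projection(InternalRow(1)).get(0, ObjectType(classOf[Row])).asInstanceOf[Row]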
@@ -66,4 +66,8 @@ object ArrayBasedMapData {
   def toScalaMap(keys: Array[Any], values: Array[Any]): Map[Any, Any] = {
     keys.zip(values).toMap
   }
+
+  def toScalaMap(keys: Seq[Any], values: Seq[Any]): Map[Any, Any] = {
+    keys.zip(values).toMap
+  }
 }
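
Both overloads simply zip keys against values into an immutable Scala map; an illustrative check:

assert(ArrayBasedMapData.toScalaMap(Seq[Any]("a", "b"), Seq[Any](1, 2)) == Map("a" -> 1, "b" -> 2))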
@@ -148,7 +148,7 @@ object RandomDataGenerator {
       () => BigDecimal.apply(
         rand.nextLong() % math.pow(10, precision).toLong,
         scale,
-        new MathContext(precision)))
+        new MathContext(precision)).bigDecimal)
     case DoubleType => randomNumeric[Double](
       rand, r => longBitsToDouble(r.nextLong()), Seq(Double.MinValue, Double.MinPositiveValue,
       Double.MaxValue, Double.PositiveInfinity, Double.NegativeInfinity, Double.NaN, 0.0))
@@ -166,7 +166,7 @@ object RandomDataGenerator {
     case NullType => Some(() => null)
     case ArrayType(elementType, containsNull) => {
       forType(elementType, nullable = containsNull, seed = Some(rand.nextLong())).map {
-        elementGenerator => () => Array.fill(rand.nextInt(MAX_ARR_SIZE))(elementGenerator())
+        elementGenerator => () => Seq.fill(rand.nextInt(MAX_ARR_SIZE))(elementGenerator())
       }
     }
     case MapType(keyType, valueType, valueContainsNull) => {
...
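
A plausible reading of the Array-to-Seq change (the commit does not spell it out): RowEncoder's externalDataTypeFor maps ArrayType to scala.collection.Seq, and Scala arrays compare by reference rather than by contents, so Seq-valued random data both matches the expected external type and survives the suite's input == convertedBack assertion:

// Arrays use reference equality, which would make round-trip comparisons fail spuriously.
assert(Array(1, 2) != Array(1, 2))
// Seqs compare structurally.
assert(Seq(1, 2) == Seq(1, 2))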
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.catalyst.encoders

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.{RandomDataGenerator, Row}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

class RowEncoderSuite extends SparkFunSuite {

  private val structOfString = new StructType().add("str", StringType)
  private val arrayOfString = ArrayType(StringType)
  private val mapOfString = MapType(StringType, StringType)

  encodeDecodeTest(
    new StructType()
      .add("boolean", BooleanType)
      .add("byte", ByteType)
      .add("short", ShortType)
      .add("int", IntegerType)
      .add("long", LongType)
      .add("float", FloatType)
      .add("double", DoubleType)
      .add("decimal", DecimalType.SYSTEM_DEFAULT)
      .add("string", StringType)
      .add("binary", BinaryType)
      .add("date", DateType)
      .add("timestamp", TimestampType))

  encodeDecodeTest(
    new StructType()
      .add("arrayOfString", arrayOfString)
      .add("arrayOfArrayOfString", ArrayType(arrayOfString))
      .add("arrayOfArrayOfInt", ArrayType(ArrayType(IntegerType)))
      .add("arrayOfMap", ArrayType(mapOfString))
      .add("arrayOfStruct", ArrayType(structOfString)))

  encodeDecodeTest(
    new StructType()
      .add("mapOfIntAndString", MapType(IntegerType, StringType))
      .add("mapOfStringAndArray", MapType(StringType, arrayOfString))
      .add("mapOfArrayAndInt", MapType(arrayOfString, IntegerType))
      .add("mapOfArray", MapType(arrayOfString, arrayOfString))
      .add("mapOfStringAndStruct", MapType(StringType, structOfString))
      .add("mapOfStructAndString", MapType(structOfString, StringType))
      .add("mapOfStruct", MapType(structOfString, structOfString)))

  encodeDecodeTest(
    new StructType()
      .add("structOfString", structOfString)
      .add("structOfStructOfString", new StructType().add("struct", structOfString))
      .add("structOfArray", new StructType().add("array", arrayOfString))
      .add("structOfMap", new StructType().add("map", mapOfString))
      .add("structOfArrayAndMap",
        new StructType().add("array", arrayOfString).add("map", mapOfString)))

  private def encodeDecodeTest(schema: StructType): Unit = {
    test(s"encode/decode: ${schema.simpleString}") {
      val encoder = RowEncoder(schema)
      val inputGenerator = RandomDataGenerator.forType(schema).get

      var input: Row = null
      try {
        for (_ <- 1 to 5) {
          input = inputGenerator.apply().asInstanceOf[Row]
          val row = encoder.toRow(input)
          val convertedBack = encoder.fromRow(row)
          assert(input == convertedBack)
        }
      } catch {
        case e: Exception =>
          fail(
            s"""
               |schema: ${schema.simpleString}
               |input: ${input}
             """.stripMargin, e)
      }
    }
  }
}