diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala
index b19538a23f19f740d39d3cb3da690c9bafea778d..1f20e26354d9e0cb19f23d410bcd05401fbb4106 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala
@@ -245,10 +245,10 @@ object Encoders {
     ExpressionEncoder[T](
       schema = new StructType().add("value", BinaryType),
       flat = true,
-      toRowExpressions = Seq(
+      serializer = Seq(
        EncodeUsingSerializer(
          BoundReference(0, ObjectType(classOf[AnyRef]), nullable = true), kryo = useKryo)),
-      fromRowExpression =
+      deserializer =
        DecodeUsingSerializer[T](
          BoundReference(0, BinaryType, nullable = true), classTag[T], kryo = useKryo),
      clsTag = classTag[T]
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
index 59ee41d02f1981ed556a357a4907623093ee2769..6f9fbbbead4744b81a03f596a80f2b158a78ebdf 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
@@ -155,16 +155,16 @@ object JavaTypeInference {
   }
 
   /**
-   * Returns an expression that can be used to construct an object of java bean `T` given an input
-   * row with a compatible schema. Fields of the row will be extracted using UnresolvedAttributes
+   * Returns an expression that can be used to deserialize an internal row to an object of java bean
+   * `T` with a compatible schema. Fields of the row will be extracted using UnresolvedAttributes
    * of the same name as the constructor arguments. Nested classes will have their fields accessed
    * using UnresolvedExtractValue.
    */
-  def constructorFor(beanClass: Class[_]): Expression = {
-    constructorFor(TypeToken.of(beanClass), None)
+  def deserializerFor(beanClass: Class[_]): Expression = {
+    deserializerFor(TypeToken.of(beanClass), None)
   }
 
-  private def constructorFor(typeToken: TypeToken[_], path: Option[Expression]): Expression = {
+  private def deserializerFor(typeToken: TypeToken[_], path: Option[Expression]): Expression = {
     /** Returns the current path with a sub-field extracted. */
     def addToPath(part: String): Expression = path
       .map(p => UnresolvedExtractValue(p, expressions.Literal(part)))
@@ -231,7 +231,7 @@ object JavaTypeInference {
       }.getOrElse {
         Invoke(
           MapObjects(
-            p => constructorFor(typeToken.getComponentType, Some(p)),
+            p => deserializerFor(typeToken.getComponentType, Some(p)),
             getPath,
             inferDataType(elementType)._1),
           "array",
@@ -243,7 +243,7 @@ object JavaTypeInference {
         val array =
           Invoke(
             MapObjects(
-              p => constructorFor(et, Some(p)),
+              p => deserializerFor(et, Some(p)),
               getPath,
               inferDataType(et)._1),
             "array",
@@ -259,7 +259,7 @@ object JavaTypeInference {
         val keyData =
           Invoke(
             MapObjects(
-              p => constructorFor(keyType, Some(p)),
+              p => deserializerFor(keyType, Some(p)),
               Invoke(getPath, "keyArray", ArrayType(keyDataType)),
               keyDataType),
             "array",
@@ -268,7 +268,7 @@ object JavaTypeInference {
         val valueData =
           Invoke(
             MapObjects(
-              p => constructorFor(valueType, Some(p)),
+              p => deserializerFor(valueType, Some(p)),
               Invoke(getPath, "valueArray", ArrayType(valueDataType)),
               valueDataType),
             "array",
@@ -288,7 +288,7 @@ object JavaTypeInference {
           val fieldName = p.getName
           val fieldType = typeToken.method(p.getReadMethod).getReturnType
           val (_, nullable) = inferDataType(fieldType)
-          val constructor = constructorFor(fieldType, Some(addToPath(fieldName)))
+          val constructor = deserializerFor(fieldType, Some(addToPath(fieldName)))
           val setter = if (nullable) {
             constructor
           } else {
@@ -313,14 +313,14 @@ object JavaTypeInference {
   }
 
   /**
-   * Returns expressions for extracting all the fields from the given type.
+   * Returns an expression for serializing an object of the given type to an internal row.
    */
-  def extractorsFor(beanClass: Class[_]): CreateNamedStruct = {
+  def serializerFor(beanClass: Class[_]): CreateNamedStruct = {
     val inputObject = BoundReference(0, ObjectType(beanClass), nullable = true)
-    extractorFor(inputObject, TypeToken.of(beanClass)).asInstanceOf[CreateNamedStruct]
+    serializerFor(inputObject, TypeToken.of(beanClass)).asInstanceOf[CreateNamedStruct]
   }
 
-  private def extractorFor(inputObject: Expression, typeToken: TypeToken[_]): Expression = {
+  private def serializerFor(inputObject: Expression, typeToken: TypeToken[_]): Expression = {
 
     def toCatalystArray(input: Expression, elementType: TypeToken[_]): Expression = {
       val (dataType, nullable) = inferDataType(elementType)
@@ -330,7 +330,7 @@ object JavaTypeInference {
           input :: Nil,
           dataType = ArrayType(dataType, nullable))
       } else {
-        MapObjects(extractorFor(_, elementType), input, ObjectType(elementType.getRawType))
+        MapObjects(serializerFor(_, elementType), input, ObjectType(elementType.getRawType))
       }
     }
 
@@ -403,7 +403,7 @@ object JavaTypeInference {
             inputObject,
             p.getReadMethod.getName,
             inferExternalType(fieldType.getRawType))
-          expressions.Literal(fieldName) :: extractorFor(fieldValue, fieldType) :: Nil
+          expressions.Literal(fieldName) :: serializerFor(fieldValue, fieldType) :: Nil
        })
     } else {
       throw new UnsupportedOperationException(
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index f208401160b8d6e22295954806aebe4871f9372e..d241b8a79bdd3ca43bd22f9c4e57a91db777c7c9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -110,8 +110,8 @@ object ScalaReflection extends ScalaReflection {
   }
 
   /**
-   * Returns an expression that can be used to construct an object of type `T` given an input
-   * row with a compatible schema. Fields of the row will be extracted using UnresolvedAttributes
+   * Returns an expression that can be used to deserialize an input row to an object of type `T`
+   * with a compatible schema. Fields of the row will be extracted using UnresolvedAttributes
    * of the same name as the constructor arguments. Nested classes will have their fields accessed
    * using UnresolvedExtractValue.
    *
@@ -119,14 +119,14 @@ object ScalaReflection extends ScalaReflection {
    * from ordinal 0 (since there are no names to map to). The actual location can be moved by
    * calling resolve/bind with a new schema.
    */
-  def constructorFor[T : TypeTag]: Expression = {
+  def deserializerFor[T : TypeTag]: Expression = {
     val tpe = localTypeOf[T]
     val clsName = getClassNameFromType(tpe)
     val walkedTypePath = s"""- root class: "${clsName}"""" :: Nil
-    constructorFor(tpe, None, walkedTypePath)
+    deserializerFor(tpe, None, walkedTypePath)
   }
 
-  private def constructorFor(
+  private def deserializerFor(
       tpe: `Type`,
       path: Option[Expression],
       walkedTypePath: Seq[String]): Expression = ScalaReflectionLock.synchronized {
@@ -161,7 +161,7 @@ object ScalaReflection extends ScalaReflection {
     }
 
     /**
-     * When we build the `fromRowExpression` for an encoder, we set up a lot of "unresolved" stuff
+     * When we build the `deserializer` for an encoder, we set up a lot of "unresolved" stuff
      * and lost the required data type, which may lead to runtime error if the real type doesn't
      * match the encoder's schema.
      * For example, we build an encoder for `case class Data(a: Int, b: String)` and the real type
@@ -188,7 +188,7 @@ object ScalaReflection extends ScalaReflection {
         val TypeRef(_, _, Seq(optType)) = t
         val className = getClassNameFromType(optType)
         val newTypePath = s"""- option value class: "$className"""" +: walkedTypePath
-        WrapOption(constructorFor(optType, path, newTypePath), dataTypeFor(optType))
+        WrapOption(deserializerFor(optType, path, newTypePath), dataTypeFor(optType))
 
       case t if t <:< localTypeOf[java.lang.Integer] =>
         val boxedType = classOf[java.lang.Integer]
@@ -272,7 +272,7 @@ object ScalaReflection extends ScalaReflection {
         val newTypePath = s"""- array element class: "$className"""" +: walkedTypePath
         Invoke(
           MapObjects(
-            p => constructorFor(elementType, Some(p), newTypePath),
+            p => deserializerFor(elementType, Some(p), newTypePath),
             getPath,
             schemaFor(elementType).dataType),
           "array",
@@ -286,7 +286,7 @@ object ScalaReflection extends ScalaReflection {
         val newTypePath = s"""- array element class: "$className"""" +: walkedTypePath
 
         val mapFunction: Expression => Expression = p => {
-          val converter = constructorFor(elementType, Some(p), newTypePath)
+          val converter = deserializerFor(elementType, Some(p), newTypePath)
           if (nullable) {
             converter
           } else {
@@ -312,7 +312,7 @@ object ScalaReflection extends ScalaReflection {
         val keyData =
           Invoke(
             MapObjects(
-              p => constructorFor(keyType, Some(p), walkedTypePath),
+              p => deserializerFor(keyType, Some(p), walkedTypePath),
               Invoke(getPath, "keyArray", ArrayType(schemaFor(keyType).dataType)),
               schemaFor(keyType).dataType),
             "array",
@@ -321,7 +321,7 @@ object ScalaReflection extends ScalaReflection {
         val valueData =
           Invoke(
             MapObjects(
-              p => constructorFor(valueType, Some(p), walkedTypePath),
+              p => deserializerFor(valueType, Some(p), walkedTypePath),
               Invoke(getPath, "valueArray", ArrayType(schemaFor(valueType).dataType)),
               schemaFor(valueType).dataType),
             "array",
@@ -344,12 +344,12 @@ object ScalaReflection extends ScalaReflection {
           val newTypePath = s"""- field (class: "$clsName", name: "$fieldName")""" +: walkedTypePath
           // For tuples, we based grab the inner fields by ordinal instead of name.
           if (cls.getName startsWith "scala.Tuple") {
-            constructorFor(
+            deserializerFor(
               fieldType,
               Some(addToPathOrdinal(i, dataType, newTypePath)),
               newTypePath)
           } else {
-            val constructor = constructorFor(
+            val constructor = deserializerFor(
               fieldType,
               Some(addToPath(fieldName, dataType, newTypePath)),
               newTypePath)
@@ -387,7 +387,7 @@ object ScalaReflection extends ScalaReflection {
   }
 
   /**
-   * Returns expressions for extracting all the fields from the given type.
+   * Returns an expression for serializing an object of type T to an internal row.
    *
    * If the given type is not supported, i.e. there is no encoder can be built for this type,
    * an [[UnsupportedOperationException]] will be thrown with detailed error message to explain
@@ -398,18 +398,18 @@ object ScalaReflection extends ScalaReflection {
    *  * the element type of [[Array]] or [[Seq]]: `array element class: "abc.xyz.MyClass"`
    *  * the field of [[Product]]: `field (class: "abc.xyz.MyClass", name: "myField")`
    */
-  def extractorsFor[T : TypeTag](inputObject: Expression): CreateNamedStruct = {
+  def serializerFor[T : TypeTag](inputObject: Expression): CreateNamedStruct = {
     val tpe = localTypeOf[T]
     val clsName = getClassNameFromType(tpe)
     val walkedTypePath = s"""- root class: "${clsName}"""" :: Nil
-    extractorFor(inputObject, tpe, walkedTypePath) match {
+    serializerFor(inputObject, tpe, walkedTypePath) match {
       case expressions.If(_, _, s: CreateNamedStruct) if tpe <:< localTypeOf[Product] => s
       case other => CreateNamedStruct(expressions.Literal("value") :: other :: Nil)
     }
   }
 
   /** Helper for extracting internal fields from a case class. */
-  private def extractorFor(
+  private def serializerFor(
       inputObject: Expression,
       tpe: `Type`,
       walkedTypePath: Seq[String]): Expression = ScalaReflectionLock.synchronized {
@@ -425,7 +425,7 @@ object ScalaReflection extends ScalaReflection {
       } else {
         val clsName = getClassNameFromType(elementType)
         val newPath = s"""- array element class: "$clsName"""" +: walkedTypePath
-        MapObjects(extractorFor(_, elementType, newPath), input, externalDataType)
+        MapObjects(serializerFor(_, elementType, newPath), input, externalDataType)
       }
     }
 
@@ -491,7 +491,7 @@ object ScalaReflection extends ScalaReflection {
           expressions.If(
             IsNull(unwrapped),
             expressions.Literal.create(null, silentSchemaFor(optType).dataType),
-            extractorFor(unwrapped, optType, newPath))
+            serializerFor(unwrapped, optType, newPath))
         }
 
       case t if t <:< localTypeOf[Product] =>
@@ -500,7 +500,7 @@ object ScalaReflection extends ScalaReflection {
           val fieldValue = Invoke(inputObject, fieldName, dataTypeFor(fieldType))
           val clsName = getClassNameFromType(fieldType)
           val newPath = s"""- field (class: "$clsName", name: "$fieldName")""" +: walkedTypePath
-          expressions.Literal(fieldName) :: extractorFor(fieldValue, fieldType, newPath) :: Nil
+          expressions.Literal(fieldName) :: serializerFor(fieldValue, fieldType, newPath) :: Nil
        })
        val nullOutput = expressions.Literal.create(null, nonNullOutput.dataType)
        expressions.If(IsNull(inputObject), nullOutput, nonNullOutput)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
index 918233ddcdaf5cdc22f490708ed45c7362619e19..1c712fde2677c449c09ef3257db3f19c259c81cd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
@@ -51,8 +51,8 @@ object ExpressionEncoder {
     val flat = !classOf[Product].isAssignableFrom(cls)
 
     val inputObject = BoundReference(0, ScalaReflection.dataTypeFor[T], nullable = false)
-    val toRowExpression = ScalaReflection.extractorsFor[T](inputObject)
-    val fromRowExpression = ScalaReflection.constructorFor[T]
+    val serializer = ScalaReflection.serializerFor[T](inputObject)
+    val deserializer = ScalaReflection.deserializerFor[T]
 
     val schema = ScalaReflection.schemaFor[T] match {
       case ScalaReflection.Schema(s: StructType, _) => s
@@ -62,8 +62,8 @@ object ExpressionEncoder {
     new ExpressionEncoder[T](
       schema,
       flat,
-      toRowExpression.flatten,
-      fromRowExpression,
+      serializer.flatten,
+      deserializer,
       ClassTag[T](cls))
   }
 
@@ -72,14 +72,14 @@ object ExpressionEncoder {
     val schema = JavaTypeInference.inferDataType(beanClass)._1
     assert(schema.isInstanceOf[StructType])
 
-    val toRowExpression = JavaTypeInference.extractorsFor(beanClass)
-    val fromRowExpression = JavaTypeInference.constructorFor(beanClass)
+    val serializer = JavaTypeInference.serializerFor(beanClass)
+    val deserializer = JavaTypeInference.deserializerFor(beanClass)
 
     new ExpressionEncoder[T](
       schema.asInstanceOf[StructType],
       flat = false,
-      toRowExpression.flatten,
-      fromRowExpression,
+      serializer.flatten,
+      deserializer,
       ClassTag[T](beanClass))
   }
 
@@ -103,9 +103,9 @@ object ExpressionEncoder {
 
     val cls = Utils.getContextOrSparkClassLoader.loadClass(s"scala.Tuple${encoders.size}")
 
-    val toRowExpressions = encoders.map {
-      case e if e.flat => e.toRowExpressions.head
-      case other => CreateStruct(other.toRowExpressions)
+    val serializer = encoders.map {
+      case e if e.flat => e.serializer.head
+      case other => CreateStruct(other.serializer)
     }.zipWithIndex.map { case (expr, index) =>
       expr.transformUp {
         case BoundReference(0, t, _) =>
@@ -116,14 +116,14 @@ object ExpressionEncoder {
       }
     }
 
-    val fromRowExpressions = encoders.zipWithIndex.map { case (enc, index) =>
+    val childrenDeserializers = encoders.zipWithIndex.map { case (enc, index) =>
       if (enc.flat) {
-        enc.fromRowExpression.transform {
+        enc.deserializer.transform {
          case b: BoundReference => b.copy(ordinal = index)
        }
      } else {
        val input = BoundReference(index, enc.schema, nullable = true)
-        enc.fromRowExpression.transformUp {
+        enc.deserializer.transformUp {
          case UnresolvedAttribute(nameParts) =>
            assert(nameParts.length == 1)
            UnresolvedExtractValue(input, Literal(nameParts.head))
@@ -132,14 +132,14 @@ object ExpressionEncoder {
       }
     }
 
-    val fromRowExpression =
-      NewInstance(cls, fromRowExpressions, ObjectType(cls), propagateNull = false)
+    val deserializer =
+      NewInstance(cls, childrenDeserializers, ObjectType(cls), propagateNull = false)
 
     new ExpressionEncoder[Any](
       schema,
       flat = false,
-      toRowExpressions,
-      fromRowExpression,
+      serializer,
+      deserializer,
       ClassTag(cls))
   }
 
@@ -174,29 +174,29 @@ object ExpressionEncoder {
  * A generic encoder for JVM objects.
  *
  * @param schema The schema after converting `T` to a Spark SQL row.
- * @param toRowExpressions A set of expressions, one for each top-level field that can be used to
- *                         extract the values from a raw object into an [[InternalRow]].
- * @param fromRowExpression An expression that will construct an object given an [[InternalRow]].
+ * @param serializer A set of expressions, one for each top-level field that can be used to
+ *                   extract the values from a raw object into an [[InternalRow]].
+ * @param deserializer An expression that will construct an object given an [[InternalRow]].
  * @param clsTag A classtag for `T`.
  */
 case class ExpressionEncoder[T](
     schema: StructType,
     flat: Boolean,
-    toRowExpressions: Seq[Expression],
-    fromRowExpression: Expression,
+    serializer: Seq[Expression],
+    deserializer: Expression,
     clsTag: ClassTag[T])
   extends Encoder[T] {
 
-  if (flat) require(toRowExpressions.size == 1)
+  if (flat) require(serializer.size == 1)
 
   @transient
-  private lazy val extractProjection = GenerateUnsafeProjection.generate(toRowExpressions)
+  private lazy val extractProjection = GenerateUnsafeProjection.generate(serializer)
 
   @transient
   private lazy val inputRow = new GenericMutableRow(1)
 
   @transient
-  private lazy val constructProjection = GenerateSafeProjection.generate(fromRowExpression :: Nil)
+  private lazy val constructProjection = GenerateSafeProjection.generate(deserializer :: Nil)
 
   /**
    * Returns this encoder where it has been bound to its own output (i.e. no remaping of columns
@@ -212,7 +212,7 @@ case class ExpressionEncoder[T](
    * Returns a new set (with unique ids) of [[NamedExpression]] that represent the serialized form
    * of this object.
    */
-  def namedExpressions: Seq[NamedExpression] = schema.map(_.name).zip(toRowExpressions).map {
+  def namedExpressions: Seq[NamedExpression] = schema.map(_.name).zip(serializer).map {
     case (_, ne: NamedExpression) => ne.newInstance()
     case (name, e) => Alias(e, name)()
   }
@@ -228,7 +228,7 @@ case class ExpressionEncoder[T](
   } catch {
     case e: Exception =>
       throw new RuntimeException(
-        s"Error while encoding: $e\n${toRowExpressions.map(_.treeString).mkString("\n")}", e)
+        s"Error while encoding: $e\n${serializer.map(_.treeString).mkString("\n")}", e)
   }
 
   /**
@@ -240,7 +240,7 @@ case class ExpressionEncoder[T](
     constructProjection(row).get(0, ObjectType(clsTag.runtimeClass)).asInstanceOf[T]
   } catch {
     case e: Exception =>
-      throw new RuntimeException(s"Error while decoding: $e\n${fromRowExpression.treeString}", e)
+      throw new RuntimeException(s"Error while decoding: $e\n${deserializer.treeString}", e)
   }
 
   /**
@@ -249,7 +249,7 @@ case class ExpressionEncoder[T](
    * has not been done already in places where we plan to do later composition of encoders.
    */
   def assertUnresolved(): Unit = {
-    (fromRowExpression +: toRowExpressions).foreach(_.foreach {
+    (deserializer +: serializer).foreach(_.foreach {
       case a: AttributeReference if a.name != "loopVar" =>
         sys.error(s"Unresolved encoder expected, but $a was found.")
       case _ =>
@@ -257,7 +257,7 @@ case class ExpressionEncoder[T](
   }
 
   /**
-   * Validates `fromRowExpression` to make sure it can be resolved by given schema, and produce
+   * Validates `deserializer` to make sure it can be resolved by given schema, and produce
    * friendly error messages to explain why it fails to resolve if there is something wrong.
    */
   def validate(schema: Seq[Attribute]): Unit = {
@@ -271,7 +271,7 @@ case class ExpressionEncoder[T](
     // If this is a tuple encoder or tupled encoder, which means its leaf nodes are all
     // `BoundReference`, make sure their ordinals are all valid.
     var maxOrdinal = -1
-    fromRowExpression.foreach {
+    deserializer.foreach {
       case b: BoundReference => if (b.ordinal > maxOrdinal) maxOrdinal = b.ordinal
       case _ =>
     }
@@ -285,7 +285,7 @@ case class ExpressionEncoder[T](
     // we unbound it by the given `schema` and propagate the actual type to `GetStructField`, after
     // we resolve the `fromRowExpression`.
     val resolved = SimpleAnalyzer.resolveExpression(
-      fromRowExpression,
+      deserializer,
       LocalRelation(schema),
       throws = true)
 
@@ -312,42 +312,39 @@ case class ExpressionEncoder[T](
   }
 
   /**
-   * Returns a new copy of this encoder, where the expressions used by `fromRow` are resolved to the
-   * given schema.
+   * Returns a new copy of this encoder, where the `deserializer` is resolved to the given schema.
    */
   def resolve(
       schema: Seq[Attribute],
       outerScopes: ConcurrentMap[String, AnyRef]): ExpressionEncoder[T] = {
-    val deserializer = SimpleAnalyzer.ResolveReferences.resolveDeserializer(
-      fromRowExpression, schema)
+    val resolved = SimpleAnalyzer.ResolveReferences.resolveDeserializer(deserializer, schema)
 
     // Make a fake plan to wrap the deserializer, so that we can go though the whole analyzer, check
     // analysis, go through optimizer, etc.
-    val plan = Project(Alias(deserializer, "")() :: Nil, LocalRelation(schema))
+    val plan = Project(Alias(resolved, "")() :: Nil, LocalRelation(schema))
     val analyzedPlan = SimpleAnalyzer.execute(plan)
     SimpleAnalyzer.checkAnalysis(analyzedPlan)
-    copy(fromRowExpression = SimplifyCasts(analyzedPlan).expressions.head.children.head)
+    copy(deserializer = SimplifyCasts(analyzedPlan).expressions.head.children.head)
   }
 
   /**
-   * Returns a copy of this encoder where the expressions used to construct an object from an input
-   * row have been bound to the ordinals of the given schema. Note that you need to first call
-   * resolve before bind.
+   * Returns a copy of this encoder where the `deserializer` has been bound to the
+   * ordinals of the given schema. Note that you need to first call resolve before bind.
    */
   def bind(schema: Seq[Attribute]): ExpressionEncoder[T] = {
-    copy(fromRowExpression = BindReferences.bindReference(fromRowExpression, schema))
+    copy(deserializer = BindReferences.bindReference(deserializer, schema))
   }
 
   /**
    * Returns a new encoder with input columns shifted by `delta` ordinals
    */
   def shift(delta: Int): ExpressionEncoder[T] = {
-    copy(fromRowExpression = fromRowExpression transform {
+    copy(deserializer = deserializer transform {
       case r: BoundReference => r.copy(ordinal = r.ordinal + delta)
     })
   }
 
-  protected val attrs = toRowExpressions.flatMap(_.collect {
+  protected val attrs = serializer.flatMap(_.collect {
     case _: UnresolvedAttribute => ""
     case a: Attribute => s"#${a.exprId}"
     case b: BoundReference => s"[${b.ordinal}]"
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
index 30f56d8c2f969641831c4b940bf6c6fdcc3fc10c..a8397aa5e5c260acb92d94f20cacd59dcb29ca0e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
@@ -36,23 +36,23 @@ object RowEncoder {
     val cls = classOf[Row]
     val inputObject = BoundReference(0, ObjectType(cls), nullable = true)
     // We use an If expression to wrap extractorsFor result of StructType
-    val extractExpressions = extractorsFor(inputObject, schema).asInstanceOf[If].falseValue
-    val constructExpression = constructorFor(schema)
+    val serializer = serializerFor(inputObject, schema).asInstanceOf[If].falseValue
+    val deserializer = deserializerFor(schema)
     new ExpressionEncoder[Row](
       schema,
       flat = false,
-      extractExpressions.asInstanceOf[CreateStruct].children,
-      constructExpression,
+      serializer.asInstanceOf[CreateStruct].children,
+      deserializer,
       ClassTag(cls))
   }
 
-  private def extractorsFor(
+  private def serializerFor(
       inputObject: Expression,
       inputType: DataType): Expression = inputType match {
     case NullType | BooleanType | ByteType | ShortType | IntegerType | LongType |
          FloatType | DoubleType | BinaryType | CalendarIntervalType => inputObject
 
-    case p: PythonUserDefinedType => extractorsFor(inputObject, p.sqlType)
+    case p: PythonUserDefinedType => serializerFor(inputObject, p.sqlType)
 
     case udt: UserDefinedType[_] =>
       val obj = NewInstance(
@@ -95,7 +95,7 @@ object RowEncoder {
           classOf[GenericArrayData],
           inputObject :: Nil,
           dataType = t)
-      case _ => MapObjects(extractorsFor(_, et), inputObject, externalDataTypeForInput(et))
+      case _ => MapObjects(serializerFor(_, et), inputObject, externalDataTypeForInput(et))
     }
 
    case t @ MapType(kt, vt, valueNullable) =>
@@ -104,14 +104,14 @@ object RowEncoder {
         Invoke(inputObject, "keysIterator", ObjectType(classOf[scala.collection.Iterator[_]])),
         "toSeq",
         ObjectType(classOf[scala.collection.Seq[_]]))
-      val convertedKeys = extractorsFor(keys, ArrayType(kt, false))
+      val convertedKeys = serializerFor(keys, ArrayType(kt, false))
 
      val values =
        Invoke(
          Invoke(inputObject, "valuesIterator", ObjectType(classOf[scala.collection.Iterator[_]])),
          "toSeq",
          ObjectType(classOf[scala.collection.Seq[_]]))
-      val convertedValues = extractorsFor(values, ArrayType(vt, valueNullable))
+      val convertedValues = serializerFor(values, ArrayType(vt, valueNullable))
 
      NewInstance(
        classOf[ArrayBasedMapData],
@@ -128,7 +128,7 @@ object RowEncoder {
         If(
           Invoke(inputObject, "isNullAt", BooleanType, Literal(i) :: Nil),
           Literal.create(null, f.dataType),
-          extractorsFor(
+          serializerFor(
             Invoke(inputObject, method, externalDataTypeForInput(f.dataType), Literal(i) :: Nil),
             f.dataType))
       }
@@ -166,7 +166,7 @@ object RowEncoder {
     case _: NullType => ObjectType(classOf[java.lang.Object])
   }
 
-  private def constructorFor(schema: StructType): Expression = {
+  private def deserializerFor(schema: StructType): Expression = {
     val fields = schema.zipWithIndex.map { case (f, i) =>
       val dt = f.dataType match {
         case p: PythonUserDefinedType => p.sqlType
@@ -176,13 +176,13 @@ object RowEncoder {
       If(
         IsNull(field),
         Literal.create(null, externalDataTypeFor(dt)),
-        constructorFor(field)
+        deserializerFor(field)
       )
     }
     CreateExternalRow(fields, schema)
   }
 
-  private def constructorFor(input: Expression): Expression = input.dataType match {
+  private def deserializerFor(input: Expression): Expression = input.dataType match {
     case NullType | BooleanType | ByteType | ShortType | IntegerType | LongType |
          FloatType | DoubleType | BinaryType | CalendarIntervalType => input
 
@@ -216,7 +216,7 @@ object RowEncoder {
     case ArrayType(et, nullable) =>
       val arrayData =
         Invoke(
-          MapObjects(constructorFor(_), input, et),
+          MapObjects(deserializerFor(_), input, et),
           "array",
           ObjectType(classOf[Array[_]]))
       StaticInvoke(
@@ -227,10 +227,10 @@ object RowEncoder {
 
     case MapType(kt, vt, valueNullable) =>
       val keyArrayType = ArrayType(kt, false)
-      val keyData = constructorFor(Invoke(input, "keyArray", keyArrayType))
+      val keyData = deserializerFor(Invoke(input, "keyArray", keyArrayType))
 
       val valueArrayType = ArrayType(vt, valueNullable)
-      val valueData = constructorFor(Invoke(input, "valueArray", valueArrayType))
+      val valueData = deserializerFor(Invoke(input, "valueArray", valueArrayType))
 
       StaticInvoke(
         ArrayBasedMapData.getClass,
@@ -243,7 +243,7 @@ object RowEncoder {
         If(
           Invoke(input, "isNullAt", BooleanType, Literal(i) :: Nil),
           Literal.create(null, externalDataTypeFor(f.dataType)),
-          constructorFor(GetStructField(input, i)))
+          deserializerFor(GetStructField(input, i)))
       }
       If(IsNull(input),
         Literal.create(null, externalDataTypeFor(input.dataType)),
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
index da7f81c7854613064e356c22db6b931b50f71814..058fb6bff1c6e06e9dedb00a108e95f1ecd06508 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
@@ -71,7 +71,7 @@ object MapPartitions {
       child: LogicalPlan): MapPartitions = {
     MapPartitions(
       func.asInstanceOf[Iterator[Any] => Iterator[Any]],
-      encoderFor[T].fromRowExpression,
+      encoderFor[T].deserializer,
       encoderFor[U].namedExpressions,
       child)
   }
@@ -98,7 +98,7 @@ object AppendColumns {
       child: LogicalPlan): AppendColumns = {
     new AppendColumns(
       func.asInstanceOf[Any => Any],
-      encoderFor[T].fromRowExpression,
+      encoderFor[T].deserializer,
       encoderFor[U].namedExpressions,
       child)
   }
@@ -133,8 +133,8 @@ object MapGroups {
       child: LogicalPlan): MapGroups = {
     new MapGroups(
       func.asInstanceOf[(Any, Iterator[Any]) => TraversableOnce[Any]],
-      encoderFor[K].fromRowExpression,
-      encoderFor[T].fromRowExpression,
+      encoderFor[K].deserializer,
+      encoderFor[T].deserializer,
       encoderFor[U].namedExpressions,
       groupingAttributes,
       dataAttributes,
@@ -178,9 +178,9 @@ object CoGroup {
 
     CoGroup(
       func.asInstanceOf[(Any, Iterator[Any], Iterator[Any]) => TraversableOnce[Any]],
-      encoderFor[Key].fromRowExpression,
-      encoderFor[Left].fromRowExpression,
-      encoderFor[Right].fromRowExpression,
+      encoderFor[Key].deserializer,
+      encoderFor[Left].deserializer,
+      encoderFor[Right].deserializer,
       encoderFor[Result].namedExpressions,
       leftGroup,
       rightGroup,
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
index dd31050bb55787769f08550df559ad338f339338..5ca5a72512a29ae0f0dfa0423d01b3f01d8ed970 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
@@ -248,10 +248,10 @@ class ScalaReflectionSuite extends SparkFunSuite {
     Seq(
       ("mirror", () => mirror),
       ("dataTypeFor", () => dataTypeFor[ComplexData]),
-      ("constructorFor", () => constructorFor[ComplexData]),
+      ("constructorFor", () => deserializerFor[ComplexData]),
       ("extractorsFor", {
         val inputObject = BoundReference(0, dataTypeForComplexData, nullable = false)
-        () => extractorsFor[ComplexData](inputObject)
+        () => serializerFor[ComplexData](inputObject)
       }),
       ("getConstructorParameters(cls)", () => getConstructorParameters(classOf[ComplexData])),
       ("getConstructorParameterNames", () => getConstructorParameterNames(classOf[ComplexData])),
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
index f6583bfe4276852d711d5e771cd33a690c946104..18752014ea908c6a262ac4b4766687ce1ba7821e 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
@@ -315,7 +315,7 @@ class ExpressionEncoderSuite extends PlanTest with AnalysisTest {
       val attr = AttributeReference("obj", ObjectType(encoder.clsTag.runtimeClass))()
       val inputPlan = LocalRelation(attr)
       val plan =
-        Project(Alias(encoder.fromRowExpression, "obj")() :: Nil,
+        Project(Alias(encoder.deserializer, "obj")() :: Nil,
           Project(encoder.namedExpressions, inputPlan))
       assertAnalysisSuccess(plan)
 
@@ -360,7 +360,7 @@ class ExpressionEncoderSuite extends PlanTest with AnalysisTest {
           |${encoder.schema.treeString}
           |
          |fromRow Expressions:
-          |${boundEncoder.fromRowExpression.treeString}
+          |${boundEncoder.deserializer.treeString}
         """.stripMargin)
     }
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
index 7ff4ffcaecd493190b628ee4e588ed117d889af4..854a662cc4d3dc43643eba6d2f25f1b2d198a1ae 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
@@ -90,7 +90,7 @@ abstract class QueryTest extends PlanTest {
           s"""
              |Exception collecting dataset as objects
              |${ds.resolvedTEncoder}
-             |${ds.resolvedTEncoder.fromRowExpression.treeString}
+             |${ds.resolvedTEncoder.deserializer.treeString}
              |${ds.queryExecution}
           """.stripMargin, e)
     }
@@ -109,7 +109,7 @@ abstract class QueryTest extends PlanTest {
       fail(
         s"""Decoded objects do not match expected objects:
            |$comparision
-           |${ds.resolvedTEncoder.fromRowExpression.treeString}
+           |${ds.resolvedTEncoder.deserializer.treeString}
         """.stripMargin)
     }
   }
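
---

Reviewer note (illustrative, not part of the patch): the rename is mechanical, but it helps to see where the two renamed fields surface in practice. Below is a minimal sketch of the encoder round trip written against the signatures visible in this diff (`ExpressionEncoder.apply`, `resolve(schema, outerScopes)`, `bind`, and the `toRow`/`fromRow` methods backed by `extractProjection`/`constructProjection`). The `Data` case class and the `EncoderRoundTrip` object are hypothetical names for this sketch, and using `OuterScopes.outerScopes` as the outer-scope map is an assumption about the surrounding codebase at this revision.

  import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, OuterScopes}

  // Hypothetical type used only for this sketch.
  case class Data(a: Int, b: String)

  object EncoderRoundTrip extends App {
    val enc = ExpressionEncoder[Data]()

    // `serializer` (formerly `toRowExpressions`) backs toRow: it projects a
    // JVM object into an InternalRow.
    val row = enc.toRow(Data(1, "hello"))

    // `deserializer` (formerly `fromRowExpression`) backs fromRow, once the
    // encoder is resolved against a schema and then bound to ordinals
    // (resolve before bind, as the updated doc comment on `bind` notes).
    val bound = enc
      .resolve(enc.schema.toAttributes, OuterScopes.outerScopes)
      .bind(enc.schema.toAttributes)

    assert(bound.fromRow(row) == Data(1, "hello"))
  }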