diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index b59eb12419c45195a6b3f16cbac353ad1bb77017..cb228cf52b43368e01020d99003d04ebc647a121 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -22,6 +22,7 @@ import scala.collection.mutable.ArrayBuffer
 
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.{CatalystConf, ScalaReflection, SimpleCatalystConf}
+import org.apache.spark.sql.catalyst.encoders.OuterScopes
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.plans._
@@ -457,25 +458,34 @@ class Analyzer(
       // When resolve `SortOrder`s in Sort based on child, don't report errors as
       // we still have chance to resolve it based on its descendants
       case s @ Sort(ordering, global, child) if child.resolved && !s.resolved =>
-        val newOrdering = resolveSortOrders(ordering, child, throws = false)
+        val newOrdering =
+          ordering.map(order => resolveExpression(order, child).asInstanceOf[SortOrder])
         Sort(newOrdering, global, child)
 
       // A special case for Generate, because the output of Generate should not be resolved by
       // ResolveReferences. Attributes in the output will be resolved by ResolveGenerate.
       case g @ Generate(generator, join, outer, qualifier, output, child)
         if child.resolved && !generator.resolved =>
-        val newG = generator transformUp {
-          case u @ UnresolvedAttribute(nameParts) =>
-            withPosition(u) { child.resolve(nameParts, resolver).getOrElse(u) }
-          case UnresolvedExtractValue(child, fieldExpr) =>
-            ExtractValue(child, fieldExpr, resolver)
-        }
+        val newG = resolveExpression(generator, child, throws = true)
         if (newG.fastEquals(generator)) {
           g
         } else {
           Generate(newG.asInstanceOf[Generator], join, outer, qualifier, output, child)
         }
 
+      // A special case for ObjectOperator, because the deserializer expressions in ObjectOperator
+      // should be resolved by their corresponding attributes instead of children's output.
+      case o: ObjectOperator if containsUnresolvedDeserializer(o.deserializers.map(_._1)) =>
+        val deserializerToAttributes = o.deserializers.map {
+          case (deserializer, attributes) => new TreeNodeRef(deserializer) -> attributes
+        }.toMap
+
+        o.transformExpressions {
+          case expr => deserializerToAttributes.get(new TreeNodeRef(expr)).map { attributes =>
+            resolveDeserializer(expr, attributes)
+          }.getOrElse(expr)
+        }
+
       case q: LogicalPlan =>
         logTrace(s"Attempting to resolve ${q.simpleString}")
         q transformExpressionsUp {
@@ -490,6 +500,32 @@ class Analyzer(
       }
     }
 
+    private def containsUnresolvedDeserializer(exprs: Seq[Expression]): Boolean = {
+      exprs.exists { expr =>
+        !expr.resolved || expr.find(_.isInstanceOf[BoundReference]).isDefined
+      }
+    }
+
+    def resolveDeserializer(
+        deserializer: Expression,
+        attributes: Seq[Attribute]): Expression = {
+      val unbound = deserializer transform {
+        case b: BoundReference => attributes(b.ordinal)
+      }
+
+      resolveExpression(unbound, LocalRelation(attributes), throws = true) transform {
+        case n: NewInstance if n.outerPointer.isEmpty && n.cls.isMemberClass =>
+          val outer = OuterScopes.outerScopes.get(n.cls.getDeclaringClass.getName)
+          if (outer == null) {
+            throw new AnalysisException(
+              s"Unable to generate an encoder for inner class `${n.cls.getName}` without " +
+                "access to the scope that this class was defined in.\n" +
+                "Try moving this class out of its parent class.")
+          }
+          n.copy(outerPointer = Some(Literal.fromObject(outer)))
+      }
+    }
+
     def newAliases(expressions: Seq[NamedExpression]): Seq[NamedExpression] = {
       expressions.map {
         case a: Alias => Alias(a.child, a.name)()
@@ -508,23 +544,20 @@ class Analyzer(
       exprs.exists(_.collect { case _: Star => true }.nonEmpty)
     }
 
-    private def resolveSortOrders(ordering: Seq[SortOrder], plan: LogicalPlan, throws: Boolean) = {
-      ordering.map { order =>
-        // Resolve SortOrder in one round.
-        // If throws == false or the desired attribute doesn't exist
-        // (like try to resolve `a.b` but `a` doesn't exist), fail and return the origin one.
-        // Else, throw exception.
-        try {
-          val newOrder = order transformUp {
-            case u @ UnresolvedAttribute(nameParts) =>
-              plan.resolve(nameParts, resolver).getOrElse(u)
-            case UnresolvedExtractValue(child, fieldName) if child.resolved =>
-              ExtractValue(child, fieldName, resolver)
-          }
-          newOrder.asInstanceOf[SortOrder]
-        } catch {
-          case a: AnalysisException if !throws => order
+    private def resolveExpression(expr: Expression, plan: LogicalPlan, throws: Boolean = false) = {
+      // Resolve expression in one round.
+      // If throws == false or the desired attribute doesn't exist
+      // (like try to resolve `a.b` but `a` doesn't exist), fail and return the origin one.
+      // Else, throw exception.
+      try {
+        expr transformUp {
+          case u @ UnresolvedAttribute(nameParts) =>
+            withPosition(u) { plan.resolve(nameParts, resolver).getOrElse(u) }
+          case UnresolvedExtractValue(child, fieldName) if child.resolved =>
+            ExtractValue(child, fieldName, resolver)
         }
+      } catch {
+        case a: AnalysisException if !throws => expr
       }
     }
 
@@ -619,7 +652,8 @@ class Analyzer(
         ordering: Seq[SortOrder],
         plan: LogicalPlan,
         child: LogicalPlan): (Seq[SortOrder], Seq[Attribute]) = {
-      val newOrdering = resolveSortOrders(ordering, child, throws = false)
+      val newOrdering =
+        ordering.map(order => resolveExpression(order, child).asInstanceOf[SortOrder])
       // Construct a set that contains all of the attributes that we need to evaluate the
       // ordering.
       val requiredAttributes = AttributeSet(newOrdering).filter(_.resolved)
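For illustration, a minimal sketch of the situation this new resolution path handles. It mirrors the `grouping key and grouped value has field with same name` test added to DatasetSuite later in this patch; `ClassNullableData` is assumed to be an existing helper case class in that suite, and the test SQLContext implicits are assumed to be in scope:

```scala
// Both the key and the value deserializers refer to fields named `a` and `b`.
// Resolving them against the child's full output would be ambiguous, so each
// deserializer is now resolved only against the attributes it is paired with
// (grouping attributes for the key, data attributes for the values).
val ds = Seq(ClassData("a", 1), ClassData("a", 2)).toDS()
val agged = ds.groupBy(d => ClassNullableData(d.a, null)).mapGroups {
  case (key, values) => key.a + values.map(_.b).sum
}
// agged.collect() === Array("a3")
```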
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
index 64832dc114e67ee4821c09ee8f75fc99956628b5..58f6d0eb9e929c56abf30a210ee19c9848ef677d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
@@ -50,7 +50,7 @@ object ExpressionEncoder {
     val cls = mirror.runtimeClass(typeTag[T].tpe)
     val flat = !classOf[Product].isAssignableFrom(cls)
 
-    val inputObject = BoundReference(0, ScalaReflection.dataTypeFor[T], nullable = true)
+    val inputObject = BoundReference(0, ScalaReflection.dataTypeFor[T], nullable = false)
     val toRowExpression = ScalaReflection.extractorsFor[T](inputObject)
     val fromRowExpression = ScalaReflection.constructorFor[T]
 
@@ -257,12 +257,10 @@ case class ExpressionEncoder[T](
   }
 
   /**
-   * Returns a new copy of this encoder, where the expressions used by `fromRow` are resolved to the
-   * given schema.
+   * Validates `fromRowExpression` to make sure it can be resolved by the given schema, and
+   * produces friendly error messages to explain why it fails to resolve if something is wrong.
    */
-  def resolve(
-      schema: Seq[Attribute],
-      outerScopes: ConcurrentMap[String, AnyRef]): ExpressionEncoder[T] = {
+  def validate(schema: Seq[Attribute]): Unit = {
     def fail(st: StructType, maxOrdinal: Int): Unit = {
       throw new AnalysisException(s"Try to map ${st.simpleString} to Tuple${maxOrdinal + 1}, " +
         "but failed as the number of fields does not line up.\n" +
@@ -270,6 +268,8 @@ case class ExpressionEncoder[T](
         " - Target schema: " + this.schema.simpleString)
     }
 
+    // If this is a tuple encoder or a tupled encoder, which means its leaf nodes are all
+    // `BoundReference`s, make sure their ordinals are all valid.
     var maxOrdinal = -1
     fromRowExpression.foreach {
       case b: BoundReference => if (b.ordinal > maxOrdinal) maxOrdinal = b.ordinal
@@ -279,6 +279,10 @@ case class ExpressionEncoder[T](
       fail(StructType.fromAttributes(schema), maxOrdinal)
     }
 
+    // If we have nested tuples, the `fromRowExpression` will contain `GetStructField` instead of
+    // `UnresolvedExtractValue`, so we need to check if their ordinals are all valid.
+    // Note that `BoundReference` contains the expected type, but here we need the actual type, so
+    // we unbind it with the given `schema` and propagate the actual type to `GetStructField`.
     val unbound = fromRowExpression transform {
       case b: BoundReference => schema(b.ordinal)
     }
@@ -299,28 +303,24 @@ case class ExpressionEncoder[T](
         fail(schema, maxOrdinal)
       }
     }
+  }
 
-    val plan = Project(Alias(unbound, "")() :: Nil, LocalRelation(schema))
+  /**
+   * Returns a new copy of this encoder, where the expressions used by `fromRow` are resolved to the
+   * given schema.
+   */
+  def resolve(
+      schema: Seq[Attribute],
+      outerScopes: ConcurrentMap[String, AnyRef]): ExpressionEncoder[T] = {
+    val deserializer = SimpleAnalyzer.ResolveReferences.resolveDeserializer(
+      fromRowExpression, schema)
+
+    // Make a fake plan to wrap the deserializer, so that we can go through the whole analyzer,
+    // check analysis, go through the optimizer, etc.
+    val plan = Project(Alias(deserializer, "")() :: Nil, LocalRelation(schema))
     val analyzedPlan = SimpleAnalyzer.execute(plan)
     SimpleAnalyzer.checkAnalysis(analyzedPlan)
-    val optimizedPlan = SimplifyCasts(analyzedPlan)
-
-    // In order to construct instances of inner classes (for example those declared in a REPL cell),
-    // we need an instance of the outer scope. This rule substitues those outer objects into
-    // expressions that are missing them by looking up the name in the SQLContexts `outerScopes`
-    // registry.
-    copy(fromRowExpression = optimizedPlan.expressions.head.children.head transform {
-      case n: NewInstance if n.outerPointer.isEmpty && n.cls.isMemberClass =>
-        val outer = outerScopes.get(n.cls.getDeclaringClass.getName)
-        if (outer == null) {
-          throw new AnalysisException(
-            s"Unable to generate an encoder for inner class `${n.cls.getName}` without access " +
-              s"to the scope that this class was defined in. " + "" +
-              "Try moving this class out of its parent class.")
-        }
-
-        n.copy(outerPointer = Some(Literal.fromObject(outer)))
-    })
+    copy(fromRowExpression = SimplifyCasts(analyzedPlan).expressions.head.children.head)
   }
 
   /**
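A sketch of how the `validate`/`resolve` split surfaces at the API level, assuming a Dataset whose schema has fewer columns than the target tuple encoder expects (this mirrors the error-message test added to DatasetSuite further down):

```scala
// validate() runs when the Dataset is constructed, so an arity mismatch now fails
// eagerly at `as`, with the "does not line up" message, rather than later when the
// deserializer is resolved or first used.
val ds = Seq(ClassData("a", 1), ClassData("b", 2)).toDS()
val e = intercept[AnalysisException] { ds.as[(String, Int, Long)] }
// e.message:
//   Try to map struct<a:string,b:int> to Tuple3, but failed as the number of fields
//   does not line up.
//    - Input schema: struct<a:string,b:int>
//    - Target schema: struct<_1:string,_2:int,_3:bigint>
```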
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
index 89d40b3b2c1419d4804e2a406c0d7d89bfc36ceb..d8f755a39c7eaffe836028710ed91a218020c280 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
@@ -154,7 +154,7 @@ object RowEncoder {
       If(
         IsNull(field),
         Literal.create(null, externalDataTypeFor(f.dataType)),
-        constructorFor(BoundReference(i, f.dataType, f.nullable))
+        constructorFor(field)
       )
     }
     CreateExternalRow(fields)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index a1ac93073916cba933732b5c8dec1ac235bbee58..902e18081bddf429a53092a313e1fb9ba136477f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -119,10 +119,13 @@ object SamplePushDown extends Rule[LogicalPlan] {
  */
 object EliminateSerialization extends Rule[LogicalPlan] {
   def apply(plan: LogicalPlan): LogicalPlan = plan transform {
-    case m @ MapPartitions(_, input, _, child: ObjectOperator)
-        if !input.isInstanceOf[Attribute] && m.input.dataType == child.outputObject.dataType =>
+    case m @ MapPartitions(_, deserializer, _, child: ObjectOperator)
+        if !deserializer.isInstanceOf[Attribute] &&
+          deserializer.dataType == child.outputObject.dataType =>
       val childWithoutSerialization = child.withObjectOutput
-      m.copy(input = childWithoutSerialization.output.head, child = childWithoutSerialization)
+      m.copy(
+        deserializer = childWithoutSerialization.output.head,
+        child = childWithoutSerialization)
   }
 }
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
index 760348052739c11041c727f0cc9b1e223b73344e..3f97662957b8edd3ea801ae8d6d78b26f21e1fa0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
@@ -20,7 +20,7 @@
 package org.apache.spark.sql.catalyst.plans.logical
 
 import org.apache.spark.sql.Encoder
 import org.apache.spark.sql.catalyst.encoders._
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.types.ObjectType
+import org.apache.spark.sql.types.{ObjectType, StructType}
 
 /**
  * A trait for logical operators that apply user defined functions to domain objects.
@@ -30,6 +30,15 @@ trait ObjectOperator extends LogicalPlan {
   /** The serializer that is used to produce the output of this operator. */
   def serializer: Seq[NamedExpression]
 
+  override def output: Seq[Attribute] = serializer.map(_.toAttribute)
+
+  /**
+   * An [[ObjectOperator]] may have one or more deserializers to convert internal rows to objects.
+   * It must also provide the attributes that are available during the resolution of each
+   * deserializer.
+   */
+  def deserializers: Seq[(Expression, Seq[Attribute])]
+
   /**
    * The object type that is produced by the user defined function. Note that the return type here
    * is the same whether or not the operator is output serialized data.
@@ -44,13 +53,13 @@ trait ObjectOperator extends LogicalPlan {
   def withObjectOutput: LogicalPlan = if (output.head.dataType.isInstanceOf[ObjectType]) {
     this
   } else {
-    withNewSerializer(outputObject)
+    withNewSerializer(outputObject :: Nil)
   }
 
   /** Returns a copy of this operator with a different serializer. */
-  def withNewSerializer(newSerializer: NamedExpression): LogicalPlan = makeCopy {
+  def withNewSerializer(newSerializer: Seq[NamedExpression]): LogicalPlan = makeCopy {
     productIterator.map {
-      case c if c == serializer => newSerializer :: Nil
+      case c if c == serializer => newSerializer
       case other: AnyRef => other
     }.toArray
   }
@@ -70,15 +79,16 @@ object MapPartitions {
 
 /**
  * A relation produced by applying `func` to each partition of the `child`.
- * @param input used to extract the input to `func` from an input row.
+ *
+ * @param deserializer used to extract the input to `func` from an input row.
  * @param serializer use to serialize the output of `func`.
  */
 case class MapPartitions(
     func: Iterator[Any] => Iterator[Any],
-    input: Expression,
+    deserializer: Expression,
     serializer: Seq[NamedExpression],
     child: LogicalPlan) extends UnaryNode with ObjectOperator {
-  override def output: Seq[Attribute] = serializer.map(_.toAttribute)
+  override def deserializers: Seq[(Expression, Seq[Attribute])] = Seq(deserializer -> child.output)
 }
 
 /** Factory for constructing new `AppendColumn` nodes. */
@@ -97,16 +107,21 @@ object AppendColumns {
 /**
  * A relation produced by applying `func` to each partition of the `child`, concatenating the
  * resulting columns at the end of the input row.
- * @param input used to extract the input to `func` from an input row.
+ *
+ * @param deserializer used to extract the input to `func` from an input row.
  * @param serializer use to serialize the output of `func`.
  */
 case class AppendColumns(
     func: Any => Any,
-    input: Expression,
+    deserializer: Expression,
     serializer: Seq[NamedExpression],
     child: LogicalPlan) extends UnaryNode with ObjectOperator {
+
   override def output: Seq[Attribute] = child.output ++ newColumns
+
   def newColumns: Seq[Attribute] = serializer.map(_.toAttribute)
+
+  override def deserializers: Seq[(Expression, Seq[Attribute])] = Seq(deserializer -> child.output)
 }
 
 /** Factory for constructing new `MapGroups` nodes. */
@@ -114,6 +129,7 @@ object MapGroups {
   def apply[K : Encoder, T : Encoder, U : Encoder](
       func: (K, Iterator[T]) => TraversableOnce[U],
       groupingAttributes: Seq[Attribute],
+      dataAttributes: Seq[Attribute],
       child: LogicalPlan): MapGroups = {
     new MapGroups(
       func.asInstanceOf[(Any, Iterator[Any]) => TraversableOnce[Any]],
@@ -121,6 +137,7 @@ object MapGroups {
       encoderFor[T].fromRowExpression,
       encoderFor[U].namedExpressions,
       groupingAttributes,
+      dataAttributes,
       child)
   }
 }
@@ -129,19 +146,22 @@ object MapGroups {
 * Applies func to each unique group in `child`, based on the evaluation of `groupingAttributes`.
 * Func is invoked with an object representation of the grouping key an iterator containing the
 * object representation of all the rows with that key.
- * @param keyObject used to extract the key object for each group.
- * @param input used to extract the items in the iterator from an input row.
+ *
+ * @param keyDeserializer used to extract the key object for each group.
+ * @param valueDeserializer used to extract the items in the iterator from an input row.
 * @param serializer use to serialize the output of `func`.
 */
 case class MapGroups(
     func: (Any, Iterator[Any]) => TraversableOnce[Any],
-    keyObject: Expression,
-    input: Expression,
+    keyDeserializer: Expression,
+    valueDeserializer: Expression,
     serializer: Seq[NamedExpression],
     groupingAttributes: Seq[Attribute],
+    dataAttributes: Seq[Attribute],
     child: LogicalPlan) extends UnaryNode with ObjectOperator {
-  def output: Seq[Attribute] = serializer.map(_.toAttribute)
+  override def deserializers: Seq[(Expression, Seq[Attribute])] =
+    Seq(keyDeserializer -> groupingAttributes, valueDeserializer -> dataAttributes)
 }
 
 /** Factory for constructing new `CoGroup` nodes. */
@@ -150,8 +170,12 @@ object CoGroup {
       func: (Key, Iterator[Left], Iterator[Right]) => TraversableOnce[Result],
       leftGroup: Seq[Attribute],
       rightGroup: Seq[Attribute],
+      leftData: Seq[Attribute],
+      rightData: Seq[Attribute],
       left: LogicalPlan,
       right: LogicalPlan): CoGroup = {
+    require(StructType.fromAttributes(leftGroup) == StructType.fromAttributes(rightGroup))
+
     CoGroup(
       func.asInstanceOf[(Any, Iterator[Any], Iterator[Any]) => TraversableOnce[Any]],
       encoderFor[Key].fromRowExpression,
@@ -160,6 +184,8 @@ object CoGroup {
       encoderFor[Result].namedExpressions,
       leftGroup,
       rightGroup,
+      leftData,
+      rightData,
       left,
       right)
   }
@@ -171,15 +197,21 @@ object CoGroup {
 */
 case class CoGroup(
     func: (Any, Iterator[Any], Iterator[Any]) => TraversableOnce[Any],
-    keyObject: Expression,
-    leftObject: Expression,
-    rightObject: Expression,
+    keyDeserializer: Expression,
+    leftDeserializer: Expression,
+    rightDeserializer: Expression,
     serializer: Seq[NamedExpression],
     leftGroup: Seq[Attribute],
     rightGroup: Seq[Attribute],
+    leftAttr: Seq[Attribute],
+    rightAttr: Seq[Attribute],
     left: LogicalPlan,
     right: LogicalPlan) extends BinaryNode with ObjectOperator {
+
   override def producedAttributes: AttributeSet = outputSet
 
-  override def output: Seq[Attribute] = serializer.map(_.toAttribute)
+  override def deserializers: Seq[(Expression, Seq[Attribute])] =
+    // The `leftGroup` and `rightGroup` are guaranteed to be of the same schema, so it's safe to
+    // resolve the `keyDeserializer` based on either of them; here we pick the left one.
+    Seq(keyDeserializer -> leftGroup, leftDeserializer -> leftAttr, rightDeserializer -> rightAttr)
 }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala
index bc36a55ae0ea2cdd873e7e5e1b96077190eac630..92a68a4dba91537e6da0135c697d908a4553ac5a 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala
@@ -127,7 +127,7 @@ class EncoderResolutionSuite extends PlanTest {
 
     {
       val attrs = Seq('a.string, 'b.long, 'c.int)
-      assert(intercept[AnalysisException](encoder.resolve(attrs, null)).message ==
+      assert(intercept[AnalysisException](encoder.validate(attrs)).message ==
         "Try to map struct<a:string,b:bigint,c:int> to Tuple2, " +
           "but failed as the number of fields does not line up.\n" +
           " - Input schema: struct<a:string,b:bigint,c:int>\n" +
@@ -136,7 +136,7 @@ class EncoderResolutionSuite extends PlanTest {
 
     {
       val attrs = Seq('a.string)
-      assert(intercept[AnalysisException](encoder.resolve(attrs, null)).message ==
+      assert(intercept[AnalysisException](encoder.validate(attrs)).message ==
         "Try to map struct<a:string> to Tuple2, " +
           "but failed as the number of fields does not line up.\n" +
           " - Input schema: struct<a:string>\n" +
@@ -149,7 +149,7 @@ class EncoderResolutionSuite extends PlanTest {
 
     {
      val attrs = Seq('a.string, 'b.struct('x.long, 'y.string, 'z.int))
-      assert(intercept[AnalysisException](encoder.resolve(attrs, null)).message ==
+      assert(intercept[AnalysisException](encoder.validate(attrs)).message ==
         "Try to map struct<x:bigint,y:string,z:int> to Tuple2, " +
           "but failed as the number of fields does not line up.\n" +
           " - Input schema: struct<a:string,b:struct<x:bigint,y:string,z:int>>\n" +
@@ -158,7 +158,7 @@ class EncoderResolutionSuite extends PlanTest {
 
     {
       val attrs = Seq('a.string, 'b.struct('x.long))
-      assert(intercept[AnalysisException](encoder.resolve(attrs, null)).message ==
+      assert(intercept[AnalysisException](encoder.validate(attrs)).message ==
         "Try to map struct<x:bigint> to Tuple2, " +
           "but failed as the number of fields does not line up.\n" +
           " - Input schema: struct<a:string,b:struct<x:bigint>>\n" +
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
index 88c558d80a79a8551a15cdaab7eeee03d4559d62..e00060f9b6aff435a3f5d1d3ba1caf6741f578e4 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
@@ -19,13 +19,10 @@ package org.apache.spark.sql.catalyst.encoders
 
 import java.sql.{Date, Timestamp}
 import java.util.Arrays
-import java.util.concurrent.ConcurrentMap
 
 import scala.collection.mutable.ArrayBuffer
 import scala.reflect.runtime.universe.TypeTag
 
-import com.google.common.collect.MapMaker
-
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.Encoders
 import org.apache.spark.sql.catalyst.{OptionalData, PrimitiveData}
@@ -78,7 +75,7 @@ class JavaSerializable(val value: Int) extends Serializable {
 }
 
 class ExpressionEncoderSuite extends SparkFunSuite {
-  OuterScopes.outerScopes.put(getClass.getName, this)
+  OuterScopes.addOuterScope(this)
 
   implicit def encoder[T : TypeTag]: ExpressionEncoder[T] = ExpressionEncoder()
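`OuterScopes.addOuterScope`, used above in place of the manual `OuterScopes.outerScopes.put(getClass.getName, this)`, registers the enclosing instance that encoders need when constructing inner classes. A short usage sketch, mirroring the `support inner class in Dataset` test added to DatasetSuite below:

```scala
class OuterClass extends Serializable {
  case class InnerClass(a: String)
}

val outer = new OuterClass
// Register the enclosing instance so that resolveDeserializer (see the Analyzer
// change above) can fill in the outer pointer of NewInstance for the inner class.
OuterScopes.addOuterScope(outer)

val ds = Seq(outer.InnerClass("1"), outer.InnerClass("2")).toDS()
// ds.map(_.a).collect() === Array("1", "2")
```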
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index f182270a08729ce96df94b29b2f613798efa4e51..378763268acc6f80bb4468ffc2b83ee6022f6a7b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -74,6 +74,7 @@ class Dataset[T] private[sql](
    * same object type (that will be possibly resolved to a different schema).
    */
   private[sql] implicit val unresolvedTEncoder: ExpressionEncoder[T] = encoderFor(tEncoder)
+  unresolvedTEncoder.validate(logicalPlan.output)
 
   /** The encoder for this [[Dataset]] that has been resolved to its output schema. */
   private[sql] val resolvedTEncoder: ExpressionEncoder[T] =
@@ -85,7 +86,7 @@ class Dataset[T] private[sql](
    */
   private[sql] val boundTEncoder = resolvedTEncoder.bind(logicalPlan.output)
 
-  private implicit def classTag = resolvedTEncoder.clsTag
+  private implicit def classTag = unresolvedTEncoder.clsTag
 
   private[sql] def this(sqlContext: SQLContext, plan: LogicalPlan)(implicit encoder: Encoder[T]) =
     this(sqlContext, new QueryExecution(sqlContext, plan), encoder)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
index b3f8284364782026eb4abf5b88d8eb0f38ac5c9e..c0e28f2dc5bd6d447cc3a2449638c5d1af926413 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
@@ -116,6 +116,7 @@ class GroupedDataset[K, V] private[sql](
       MapGroups(
         f,
         groupingAttributes,
+        dataAttributes,
         logicalPlan))
   }
 
@@ -310,6 +311,8 @@ class GroupedDataset[K, V] private[sql](
         f,
         this.groupingAttributes,
         other.groupingAttributes,
+        this.dataAttributes,
+        other.dataAttributes,
         this.logicalPlan,
         other.logicalPlan))
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 9293e55141757eaa2f19220df725ad604b1c18fd..830bb011beab4750a07b4a399210b71dcd217104 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -306,11 +306,12 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
         execution.MapPartitions(f, in, out, planLater(child)) :: Nil
       case logical.AppendColumns(f, in, out, child) =>
         execution.AppendColumns(f, in, out, planLater(child)) :: Nil
-      case logical.MapGroups(f, key, in, out, grouping, child) =>
-        execution.MapGroups(f, key, in, out, grouping, planLater(child)) :: Nil
-      case logical.CoGroup(f, keyObj, lObj, rObj, out, lGroup, rGroup, left, right) =>
+      case logical.MapGroups(f, key, in, out, grouping, data, child) =>
+        execution.MapGroups(f, key, in, out, grouping, data, planLater(child)) :: Nil
+      case logical.CoGroup(f, keyObj, lObj, rObj, out, lGroup, rGroup, lAttr, rAttr, left, right) =>
         execution.CoGroup(
-          f, keyObj, lObj, rObj, out, lGroup, rGroup, planLater(left), planLater(right)) :: Nil
+          f, keyObj, lObj, rObj, out, lGroup, rGroup, lAttr, rAttr,
+          planLater(left), planLater(right)) :: Nil
 
       case logical.Repartition(numPartitions, shuffle, child) =>
         if (shuffle) {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
index 2acca1743cbb905dc7e637b0f0d18f1a912dfdbb..582dda8603f4e4dde9ebf9304a4f18d09dd61a0d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
@@ -53,14 +53,14 @@ trait ObjectOperator extends SparkPlan {
  */
 case class MapPartitions(
     func: Iterator[Any] => Iterator[Any],
-    input: Expression,
+    deserializer: Expression,
     serializer: Seq[NamedExpression],
     child: SparkPlan) extends UnaryNode with ObjectOperator {
   override def output: Seq[Attribute] = serializer.map(_.toAttribute)
 
   override protected def doExecute(): RDD[InternalRow] = {
     child.execute().mapPartitionsInternal { iter =>
-      val getObject = generateToObject(input, child.output)
+      val getObject = generateToObject(deserializer, child.output)
       val outputObject = generateToRow(serializer)
       func(iter.map(getObject)).map(outputObject)
     }
@@ -72,7 +72,7 @@ case class MapPartitions(
  */
 case class AppendColumns(
     func: Any => Any,
-    input: Expression,
+    deserializer: Expression,
     serializer: Seq[NamedExpression],
     child: SparkPlan) extends UnaryNode with ObjectOperator {
 
@@ -82,7 +82,7 @@ case class AppendColumns(
 
   override protected def doExecute(): RDD[InternalRow] = {
     child.execute().mapPartitionsInternal { iter =>
-      val getObject = generateToObject(input, child.output)
+      val getObject = generateToObject(deserializer, child.output)
       val combiner = GenerateUnsafeRowJoiner.create(child.schema, newColumnSchema)
       val outputObject = generateToRow(serializer)
 
@@ -103,10 +103,11 @@ case class AppendColumns(
  */
 case class MapGroups(
     func: (Any, Iterator[Any]) => TraversableOnce[Any],
-    keyObject: Expression,
-    input: Expression,
+    keyDeserializer: Expression,
+    valueDeserializer: Expression,
     serializer: Seq[NamedExpression],
     groupingAttributes: Seq[Attribute],
+    dataAttributes: Seq[Attribute],
     child: SparkPlan) extends UnaryNode with ObjectOperator {
   override def output: Seq[Attribute] = serializer.map(_.toAttribute)
 
@@ -121,8 +122,8 @@ case class MapGroups(
     child.execute().mapPartitionsInternal { iter =>
       val grouped = GroupedIterator(iter, groupingAttributes, child.output)
 
-      val getKey = generateToObject(keyObject, groupingAttributes)
-      val getValue = generateToObject(input, child.output)
+      val getKey = generateToObject(keyDeserializer, groupingAttributes)
+      val getValue = generateToObject(valueDeserializer, dataAttributes)
       val outputObject = generateToRow(serializer)
 
       grouped.flatMap { case (key, rowIter) =>
@@ -142,12 +143,14 @@ case class MapGroups(
  */
 case class CoGroup(
     func: (Any, Iterator[Any], Iterator[Any]) => TraversableOnce[Any],
-    keyObject: Expression,
-    leftObject: Expression,
-    rightObject: Expression,
+    keyDeserializer: Expression,
+    leftDeserializer: Expression,
+    rightDeserializer: Expression,
     serializer: Seq[NamedExpression],
     leftGroup: Seq[Attribute],
     rightGroup: Seq[Attribute],
+    leftAttr: Seq[Attribute],
+    rightAttr: Seq[Attribute],
     left: SparkPlan,
     right: SparkPlan) extends BinaryNode with ObjectOperator {
 
@@ -164,9 +167,9 @@ case class CoGroup(
       val leftGrouped = GroupedIterator(leftData, leftGroup, left.output)
       val rightGrouped = GroupedIterator(rightData, rightGroup, right.output)
 
-      val getKey = generateToObject(keyObject, leftGroup)
-      val getLeft = generateToObject(leftObject, left.output)
-      val getRight = generateToObject(rightObject, right.output)
+      val getKey = generateToObject(keyDeserializer, leftGroup)
+      val getLeft = generateToObject(leftDeserializer, leftAttr)
+      val getRight = generateToObject(rightDeserializer, rightAttr)
       val outputObject = generateToRow(serializer)
 
       new CoGroupedIterator(leftGrouped, rightGrouped, leftGroup).flatMap {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index b69bb21db532b14082e19a0efc26bda92077ab07..374f4320a92391a870f498430056bc0ea87e4002 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -22,6 +22,7 @@ import java.sql.{Date, Timestamp}
 
 import scala.language.postfixOps
 
+import org.apache.spark.sql.catalyst.encoders.OuterScopes
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
@@ -45,13 +46,12 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
       1, 1, 1)
   }
 
-
   test("SPARK-12404: Datatype Helper Serializablity") {
     val ds = sparkContext.parallelize((
-      new Timestamp(0),
-      new Date(0),
-      java.math.BigDecimal.valueOf(1),
-      scala.math.BigDecimal(1)) :: Nil).toDS()
+        new Timestamp(0),
+        new Date(0),
+        java.math.BigDecimal.valueOf(1),
+        scala.math.BigDecimal(1)) :: Nil).toDS()
 
     ds.collect()
   }
@@ -523,7 +523,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
   test("verify mismatching field names fail with a good error") {
     val ds = Seq(ClassData("a", 1)).toDS()
     val e = intercept[AnalysisException] {
-      ds.as[ClassData2].collect()
+      ds.as[ClassData2]
     }
     assert(e.getMessage.contains("cannot resolve 'c' given input columns: [a, b]"), e.getMessage)
   }
@@ -567,6 +567,58 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
     checkAnswer(ds1, DeepNestedStruct(NestedStruct(null)))
     checkAnswer(ds1.toDF(), Row(Row(null)))
   }
+
+  test("support inner class in Dataset") {
+    val outer = new OuterClass
+    OuterScopes.addOuterScope(outer)
+    val ds = Seq(outer.InnerClass("1"), outer.InnerClass("2")).toDS()
+    checkAnswer(ds.map(_.a), "1", "2")
+  }
+
+  test("grouping key and grouped value has field with same name") {
+    val ds = Seq(ClassData("a", 1), ClassData("a", 2)).toDS()
+    val agged = ds.groupBy(d => ClassNullableData(d.a, null)).mapGroups {
+      case (key, values) => key.a + values.map(_.b).sum
+    }
+
+    checkAnswer(agged, "a3")
+  }
+
+  test("cogroup's left and right side has field with same name") {
+    val left = Seq(ClassData("a", 1), ClassData("b", 2)).toDS()
+    val right = Seq(ClassNullableData("a", 3), ClassNullableData("b", 4)).toDS()
+    val cogrouped = left.groupBy(_.a).cogroup(right.groupBy(_.a)) {
+      case (key, lData, rData) => Iterator(key + lData.map(_.b).sum + rData.map(_.b.toInt).sum)
+    }
+
+    checkAnswer(cogrouped, "a13", "b24")
+  }
+
+  test("give nice error message when the real number of fields doesn't match encoder schema") {
+    val ds = Seq(ClassData("a", 1), ClassData("b", 2)).toDS()
+
+    val message = intercept[AnalysisException] {
+      ds.as[(String, Int, Long)]
+    }.message
+    assert(message ==
+      "Try to map struct<a:string,b:int> to Tuple3, " +
+        "but failed as the number of fields does not line up.\n" +
+        " - Input schema: struct<a:string,b:int>\n" +
+        " - Target schema: struct<_1:string,_2:int,_3:bigint>")
+
+    val message2 = intercept[AnalysisException] {
+      ds.as[Tuple1[String]]
+    }.message
+    assert(message2 ==
+      "Try to map struct<a:string,b:int> to Tuple1, " +
+        "but failed as the number of fields does not line up.\n" +
+        " - Input schema: struct<a:string,b:int>\n" +
+        " - Target schema: struct<_1:string>")
+  }
+}
+
+class OuterClass extends Serializable {
+  case class InnerClass(a: String)
 }
 
 case class ClassData(a: String, b: Int)
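Finally, a sketch of the cogroup scenario that motivated threading `leftData`/`rightData` (exposed as `leftAttr`/`rightAttr` on `CoGroup`) through the plans: each side's deserializer is resolved only against that side's data attributes, so identically named fields on the two inputs no longer clash. This is the same case exercised by the `cogroup's left and right side has field with same name` test above; `ClassNullableData` is assumed to be the suite's existing helper case class.

```scala
val left = Seq(ClassData("a", 1), ClassData("b", 2)).toDS()
val right = Seq(ClassNullableData("a", 3), ClassNullableData("b", 4)).toDS()

// Both sides expose fields named `a` and `b`; the left deserializer resolves against
// leftAttr and the right one against rightAttr, never against the combined output.
val cogrouped = left.groupBy(_.a).cogroup(right.groupBy(_.a)) {
  case (key, lData, rData) => Iterator(key + lData.map(_.b).sum + rData.map(_.b.toInt).sum)
}
// cogrouped.collect().toSet === Set("a13", "b24")
```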