diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 8a33af8207350563bcdd91e2dc982e9d002a5449..dadea6b54a946e647d33f657211f0befd04ed4f8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -1214,6 +1214,10 @@ object CleanupAliases extends Rule[LogicalPlan] {
       Window(projectList, cleanedWindowExprs, partitionSpec.map(trimAliases),
         orderSpec.map(trimAliases(_).asInstanceOf[SortOrder]), child)
 
+    // Operators that operate on objects should only have expressions from encoders, which should
+    // never have extra aliases.
+    case o: ObjectOperator => o
+
     case other =>
       var stop = false
       other transformExpressionsDown {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
index fc0e87aa68ed4626ae0562e46cf3902dbe92ca5d..79eebbf9b1ec45c338558ae075605a6344d63290 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
@@ -160,6 +160,7 @@ abstract class Star extends LeafExpression with NamedExpression {
   override def nullable: Boolean = throw new UnresolvedException(this, "nullable")
   override def qualifiers: Seq[String] = throw new UnresolvedException(this, "qualifiers")
   override def toAttribute: Attribute = throw new UnresolvedException(this, "toAttribute")
+  override def newInstance(): NamedExpression = throw new UnresolvedException(this, "newInstance")
   override lazy val resolved = false
 
   def expand(input: LogicalPlan, resolver: Resolver): Seq[NamedExpression]
@@ -246,6 +247,8 @@ case class MultiAlias(child: Expression, names: Seq[String])
 
   override def toAttribute: Attribute = throw new UnresolvedException(this, "toAttribute")
 
+  override def newInstance(): NamedExpression = throw new UnresolvedException(this, "newInstance")
+
   override lazy val resolved = false
 
   override def toString: String = s"$child AS $names"
@@ -259,6 +262,7 @@ case class MultiAlias(child: Expression, names: Seq[String])
  * @param expressions Expressions to expand.
  */
 case class ResolvedStar(expressions: Seq[NamedExpression]) extends Star with Unevaluable {
+  override def newInstance(): NamedExpression = throw new UnresolvedException(this, "newInstance")
   override def expand(input: LogicalPlan, resolver: Resolver): Seq[NamedExpression] = expressions
   override def toString: String = expressions.mkString("ResolvedStar(", ", ", ")")
 }
@@ -298,6 +302,7 @@ case class UnresolvedAlias(child: Expression, aliasName: Option[String] = None)
   override def nullable: Boolean = throw new UnresolvedException(this, "nullable")
   override def dataType: DataType = throw new UnresolvedException(this, "dataType")
   override def name: String = throw new UnresolvedException(this, "name")
+  override def newInstance(): NamedExpression = throw new UnresolvedException(this, "newInstance")
 
   override lazy val resolved = false
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
index 05f746e72b498fb062ec935bf963594711fe560b..64832dc114e67ee4821c09ee8f75fc99956628b5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
@@ -207,6 +207,16 @@ case class ExpressionEncoder[T](
     resolve(attrs, OuterScopes.outerScopes).bind(attrs)
   }
 
+
+  /**
+   * Returns a new set (with unique ids) of [[NamedExpression]] that represent the serialized form
+   * of this object.
+   */
+  def namedExpressions: Seq[NamedExpression] = schema.map(_.name).zip(toRowExpressions).map {
+    case (_, ne: NamedExpression) => ne.newInstance()
+    case (name, e) => Alias(e, name)()
+  }
+
   /**
    * Returns an encoded version of `t` as a Spark SQL row. Note that multiple calls to
    * toRow are allowed to return the same actual [[InternalRow]] object. Thus, the caller should
    * copy the result before making another call if required.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
index 7293d5d4472afb372de486d6b1b3974780290b65..c94b2c0e270b653a55e739e20a2c6fb449a2f891 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
@@ -31,7 +31,7 @@ import org.apache.spark.sql.types._
 case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean)
   extends LeafExpression with NamedExpression {
 
-  override def toString: String = s"input[$ordinal, $dataType]"
+  override def toString: String = s"input[$ordinal, ${dataType.simpleString}]"
 
   // Use special getter for primitive types (for UnsafeRow)
   override def eval(input: InternalRow): Any = {
@@ -66,6 +66,8 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean)
 
   override def exprId: ExprId = throw new UnsupportedOperationException
 
+  override def newInstance(): NamedExpression = this
+
   override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
     val javaType = ctx.javaType(dataType)
     val value = ctx.getValue(ctx.INPUT_ROW, dataType, ordinal.toString)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
index eee708cb02f9dd3a2b8f017955d114fa4e02f763..b6d7a7f5e8d01389d5b7d96bbfd5e10b7fe83956 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
@@ -79,6 +79,9 @@ trait NamedExpression extends Expression {
   /** Returns the metadata when an expression is a reference to another expression with metadata. */
   def metadata: Metadata = Metadata.empty
 
+  /** Returns a copy of this expression with a new `exprId`. */
+  def newInstance(): NamedExpression
+
   protected def typeSuffix =
     if (resolved) {
       dataType match {
@@ -144,6 +147,9 @@ case class Alias(child: Expression, name: String)(
     }
   }
 
+  def newInstance(): NamedExpression =
+    Alias(child, name)(qualifiers = qualifiers, explicitMetadata = explicitMetadata)
+
   override def toAttribute: Attribute = {
     if (resolved) {
       AttributeReference(name, child.dataType, child.nullable, metadata)(exprId, qualifiers)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
index c0c3e6e891669f8e98a810915d51b001f852d77c..8385f7e1da591789e2b17a2dbba709225dfbcc17 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
@@ -172,6 +172,8 @@ case class Invoke(
       $objNullCheck
     """
   }
+
+  override def toString: String = s"$targetObject.$functionName"
 }
 
 object NewInstance {
@@ -253,6 +255,8 @@ case class NewInstance(
       """
     }
   }
+
+  override def toString: String = s"newInstance($cls)"
 }
 
 /**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 487431f8925a376da52c472edce2a69a13e66143..cc3371c08fac4a4ed1fbbbc66bcf6bbfaced052c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -67,7 +67,8 @@ abstract class Optimizer extends RuleExecutor[LogicalPlan] {
       RemoveDispensableExpressions,
       SimplifyFilters,
       SimplifyCasts,
-      SimplifyCaseConversionExpressions) ::
+      SimplifyCaseConversionExpressions,
+      EliminateSerialization) ::
     Batch("Decimal Optimizations", FixedPoint(100),
       DecimalAggregates) ::
     Batch("LocalRelation", FixedPoint(100),
@@ -96,6 +97,19 @@ object SamplePushDown extends Rule[LogicalPlan] {
   }
 }
 
+/**
+ * Removes cases where we unnecessarily go between the object and the serialized (InternalRow)
+ * representation of a data item, for example in back-to-back map operations.
+ */
+object EliminateSerialization extends Rule[LogicalPlan] {
+  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+    case m @ MapPartitions(_, input, _, child: ObjectOperator)
+        if !input.isInstanceOf[Attribute] && m.input.dataType == child.outputObject.dataType =>
+      val childWithoutSerialization = child.withObjectOutput
+      m.copy(input = childWithoutSerialization.output.head, child = childWithoutSerialization)
+  }
+}
+
 /**
  * Pushes certain operations to both sides of a Union, Intersect or Except operator.
  * Operations that are safe to pushdown are listed as follows.
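To make the effect of EliminateSerialization concrete at the API level: before this rule, every typed operation deserialized its input and re-serialized its output, so two adjacent maps paid for an unnecessary round trip through InternalRow. A hedged illustration, not part of the patch itself; it assumes a SQLContext named sqlContext with its implicits in scope:

    import sqlContext.implicits._

    val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDS()
    // Two adjacent MapPartitions nodes over the same object type: the optimizer can
    // now feed the object produced by the first directly into the second, leaving a
    // single object creation per row (see the new EliminateSerializationSuite below).
    val chained = ds
      .mapPartitions(_.map { case (s, i) => (s, i + 1) })
      .mapPartitions(_.map { case (s, i) => (s, i * 2) })
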
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index 64957db6b40135acdb7c06221fc2ad3a0ce06902..2a1b1b131d813697f7b9df65a46d2a27941e2bea 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -19,8 +19,6 @@ package org.apache.spark.sql.catalyst.plans.logical
 
 import scala.collection.mutable.ArrayBuffer
 
-import org.apache.spark.sql.Encoder
-import org.apache.spark.sql.catalyst.encoders._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
 import org.apache.spark.sql.catalyst.plans._
@@ -480,120 +478,3 @@ case object OneRowRelation extends LeafNode {
    */
   override def statistics: Statistics = Statistics(sizeInBytes = 1)
 }
-
-/**
- * A relation produced by applying `func` to each partition of the `child`. tEncoder/uEncoder are
- * used respectively to decode/encode from the JVM object representation expected by `func.`
- */
-case class MapPartitions[T, U](
-    func: Iterator[T] => Iterator[U],
-    tEncoder: ExpressionEncoder[T],
-    uEncoder: ExpressionEncoder[U],
-    output: Seq[Attribute],
-    child: LogicalPlan) extends UnaryNode {
-  override def producedAttributes: AttributeSet = outputSet
-}
-
-/** Factory for constructing new `AppendColumn` nodes. */
-object AppendColumns {
-  def apply[T, U : Encoder](
-      func: T => U,
-      tEncoder: ExpressionEncoder[T],
-      child: LogicalPlan): AppendColumns[T, U] = {
-    val attrs = encoderFor[U].schema.toAttributes
-    new AppendColumns[T, U](func, tEncoder, encoderFor[U], attrs, child)
-  }
-}
-
-/**
- * A relation produced by applying `func` to each partition of the `child`, concatenating the
- * resulting columns at the end of the input row. tEncoder/uEncoder are used respectively to
- * decode/encode from the JVM object representation expected by `func.`
- */
-case class AppendColumns[T, U](
-    func: T => U,
-    tEncoder: ExpressionEncoder[T],
-    uEncoder: ExpressionEncoder[U],
-    newColumns: Seq[Attribute],
-    child: LogicalPlan) extends UnaryNode {
-  override def output: Seq[Attribute] = child.output ++ newColumns
-  override def producedAttributes: AttributeSet = AttributeSet(newColumns)
-}
-
-/** Factory for constructing new `MapGroups` nodes. */
-object MapGroups {
-  def apply[K, T, U : Encoder](
-      func: (K, Iterator[T]) => TraversableOnce[U],
-      kEncoder: ExpressionEncoder[K],
-      tEncoder: ExpressionEncoder[T],
-      groupingAttributes: Seq[Attribute],
-      child: LogicalPlan): MapGroups[K, T, U] = {
-    new MapGroups(
-      func,
-      kEncoder,
-      tEncoder,
-      encoderFor[U],
-      groupingAttributes,
-      encoderFor[U].schema.toAttributes,
-      child)
-  }
-}
-
-/**
- * Applies func to each unique group in `child`, based on the evaluation of `groupingAttributes`.
- * Func is invoked with an object representation of the grouping key an iterator containing the
- * object representation of all the rows with that key.
- */
-case class MapGroups[K, T, U](
-    func: (K, Iterator[T]) => TraversableOnce[U],
-    kEncoder: ExpressionEncoder[K],
-    tEncoder: ExpressionEncoder[T],
-    uEncoder: ExpressionEncoder[U],
-    groupingAttributes: Seq[Attribute],
-    output: Seq[Attribute],
-    child: LogicalPlan) extends UnaryNode {
-  override def producedAttributes: AttributeSet = outputSet
-}
-
-/** Factory for constructing new `CoGroup` nodes. */
-object CoGroup {
-  def apply[Key, Left, Right, Result : Encoder](
-      func: (Key, Iterator[Left], Iterator[Right]) => TraversableOnce[Result],
-      keyEnc: ExpressionEncoder[Key],
-      leftEnc: ExpressionEncoder[Left],
-      rightEnc: ExpressionEncoder[Right],
-      leftGroup: Seq[Attribute],
-      rightGroup: Seq[Attribute],
-      left: LogicalPlan,
-      right: LogicalPlan): CoGroup[Key, Left, Right, Result] = {
-    CoGroup(
-      func,
-      keyEnc,
-      leftEnc,
-      rightEnc,
-      encoderFor[Result],
-      encoderFor[Result].schema.toAttributes,
-      leftGroup,
-      rightGroup,
-      left,
-      right)
-  }
-}
-
-/**
- * A relation produced by applying `func` to each grouping key and associated values from left and
- * right children.
- */
-case class CoGroup[Key, Left, Right, Result](
-    func: (Key, Iterator[Left], Iterator[Right]) => TraversableOnce[Result],
-    keyEnc: ExpressionEncoder[Key],
-    leftEnc: ExpressionEncoder[Left],
-    rightEnc: ExpressionEncoder[Right],
-    resultEnc: ExpressionEncoder[Result],
-    output: Seq[Attribute],
-    leftGroup: Seq[Attribute],
-    rightGroup: Seq[Attribute],
-    left: LogicalPlan,
-    right: LogicalPlan) extends BinaryNode {
-  override def producedAttributes: AttributeSet = outputSet
-}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
new file mode 100644
index 0000000000000000000000000000000000000000..760348052739c11041c727f0cc9b1e223b73344e
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
@@ -0,0 +1,185 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.plans.logical
+
+import org.apache.spark.sql.Encoder
+import org.apache.spark.sql.catalyst.encoders._
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.types.ObjectType
+
+/**
+ * A trait for logical operators that apply user defined functions to domain objects.
+ */
+trait ObjectOperator extends LogicalPlan {
+
+  /** The serializer that is used to produce the output of this operator. */
+  def serializer: Seq[NamedExpression]
+
+  /**
+   * The object type that is produced by the user defined function. Note that the return type here
+   * is the same whether or not the operator outputs serialized data.
+   */
+  def outputObject: NamedExpression =
+    Alias(serializer.head.collect { case b: BoundReference => b }.head, "obj")()
+
+  /**
+   * Returns a copy of this operator that will produce an object instead of an encoded row.
+   * Used in the optimizer when transforming plans to remove unneeded serialization.
+   */
+  def withObjectOutput: LogicalPlan = if (output.head.dataType.isInstanceOf[ObjectType]) {
+    this
+  } else {
+    withNewSerializer(outputObject)
+  }
+
+  /** Returns a copy of this operator with a different serializer. */
+  def withNewSerializer(newSerializer: NamedExpression): LogicalPlan = makeCopy {
+    productIterator.map {
+      case c if c == serializer => newSerializer :: Nil
+      case other: AnyRef => other
+    }.toArray
+  }
+}
+
+object MapPartitions {
+  def apply[T : Encoder, U : Encoder](
+      func: Iterator[T] => Iterator[U],
+      child: LogicalPlan): MapPartitions = {
+    MapPartitions(
+      func.asInstanceOf[Iterator[Any] => Iterator[Any]],
+      encoderFor[T].fromRowExpression,
+      encoderFor[U].namedExpressions,
+      child)
+  }
+}
+
+/**
+ * A relation produced by applying `func` to each partition of the `child`.
+ * @param input used to extract the input to `func` from an input row.
+ * @param serializer used to serialize the output of `func`.
+ */
+case class MapPartitions(
+    func: Iterator[Any] => Iterator[Any],
+    input: Expression,
+    serializer: Seq[NamedExpression],
+    child: LogicalPlan) extends UnaryNode with ObjectOperator {
+  override def output: Seq[Attribute] = serializer.map(_.toAttribute)
+}
+
+/** Factory for constructing new `AppendColumns` nodes. */
+object AppendColumns {
+  def apply[T : Encoder, U : Encoder](
+      func: T => U,
+      child: LogicalPlan): AppendColumns = {
+    new AppendColumns(
+      func.asInstanceOf[Any => Any],
+      encoderFor[T].fromRowExpression,
+      encoderFor[U].namedExpressions,
+      child)
+  }
+}
+
+/**
+ * A relation produced by applying `func` to each partition of the `child`, concatenating the
+ * resulting columns at the end of the input row.
+ * @param input used to extract the input to `func` from an input row.
+ * @param serializer used to serialize the output of `func`.
+ */
+case class AppendColumns(
+    func: Any => Any,
+    input: Expression,
+    serializer: Seq[NamedExpression],
+    child: LogicalPlan) extends UnaryNode with ObjectOperator {
+  override def output: Seq[Attribute] = child.output ++ newColumns
+  def newColumns: Seq[Attribute] = serializer.map(_.toAttribute)
+}
+
+/** Factory for constructing new `MapGroups` nodes. */
+object MapGroups {
+  def apply[K : Encoder, T : Encoder, U : Encoder](
+      func: (K, Iterator[T]) => TraversableOnce[U],
+      groupingAttributes: Seq[Attribute],
+      child: LogicalPlan): MapGroups = {
+    new MapGroups(
+      func.asInstanceOf[(Any, Iterator[Any]) => TraversableOnce[Any]],
+      encoderFor[K].fromRowExpression,
+      encoderFor[T].fromRowExpression,
+      encoderFor[U].namedExpressions,
+      groupingAttributes,
+      child)
+  }
+}
+
+/**
+ * Applies `func` to each unique group in `child`, based on the evaluation of `groupingAttributes`.
+ * `func` is invoked with an object representation of the grouping key and an iterator containing
+ * the object representation of all the rows with that key.
+ * @param keyObject used to extract the key object for each group.
+ * @param input used to extract the items in the iterator from an input row.
+ * @param serializer used to serialize the output of `func`.
+ */
+case class MapGroups(
+    func: (Any, Iterator[Any]) => TraversableOnce[Any],
+    keyObject: Expression,
+    input: Expression,
+    serializer: Seq[NamedExpression],
+    groupingAttributes: Seq[Attribute],
+    child: LogicalPlan) extends UnaryNode with ObjectOperator {
+
+  def output: Seq[Attribute] = serializer.map(_.toAttribute)
+}
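The factories above erase the static type parameters by casting the user function to its Any-typed shape. This is sound because Scala function types are erased at runtime, and the physical operators only ever feed the function objects decoded with the matching encoder. A standalone sketch of the pattern, illustrative only and not taken from the patch:

    // Erasure makes the cast a no-op; `erased` is the very same Function1 object.
    val typed: Iterator[Int] => Iterator[String] = _.map(_.toString)
    val erased = typed.asInstanceOf[Iterator[Any] => Iterator[Any]]

    // Safe as long as the elements really are Ints, which the enclosing operator
    // guarantees by decoding its input with the encoder for the function's domain.
    assert(erased(Iterator(1, 2, 3)).toSeq == Seq("1", "2", "3"))
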
+/** Factory for constructing new `CoGroup` nodes. */
+object CoGroup {
+  def apply[Key : Encoder, Left : Encoder, Right : Encoder, Result : Encoder](
+      func: (Key, Iterator[Left], Iterator[Right]) => TraversableOnce[Result],
+      leftGroup: Seq[Attribute],
+      rightGroup: Seq[Attribute],
+      left: LogicalPlan,
+      right: LogicalPlan): CoGroup = {
+    CoGroup(
+      func.asInstanceOf[(Any, Iterator[Any], Iterator[Any]) => TraversableOnce[Any]],
+      encoderFor[Key].fromRowExpression,
+      encoderFor[Left].fromRowExpression,
+      encoderFor[Right].fromRowExpression,
+      encoderFor[Result].namedExpressions,
+      leftGroup,
+      rightGroup,
+      left,
+      right)
+  }
+}
+
+/**
+ * A relation produced by applying `func` to each grouping key and associated values from left and
+ * right children.
+ */
+case class CoGroup(
+    func: (Any, Iterator[Any], Iterator[Any]) => TraversableOnce[Any],
+    keyObject: Expression,
+    leftObject: Expression,
+    rightObject: Expression,
+    serializer: Seq[NamedExpression],
+    leftGroup: Seq[Attribute],
+    rightGroup: Seq[Attribute],
+    left: LogicalPlan,
+    right: LogicalPlan) extends BinaryNode with ObjectOperator {
+  override def producedAttributes: AttributeSet = outputSet
+
+  override def output: Seq[Attribute] = serializer.map(_.toAttribute)
+}
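One subtlety in the ObjectOperator trait at the top of this file: withNewSerializer copies an arbitrary operator by walking productIterator, swapping the field that equals the current serializer, and rebuilding the node via makeCopy. A minimal standalone analog of that copying trick, with a plain case class standing in for a Catalyst TreeNode (all names here are hypothetical):

    case class FakePlan(name: String, serializer: Seq[String])

    def replaceSerializer(p: FakePlan, newSerializer: String): FakePlan = {
      val newArgs = p.productIterator.map {
        case s if s == p.serializer => Seq(newSerializer) // swap the matching field
        case other: AnyRef => other
      }.toArray
      // Catalyst calls makeCopy(newArgs), which reflectively invokes the primary
      // constructor; for a plain case class we apply it by hand.
      FakePlan(newArgs(0).asInstanceOf[String], newArgs(1).asInstanceOf[Seq[String]])
    }

    // replaceSerializer(FakePlan("map", Seq("a", "b")), "obj") == FakePlan("map", Seq("obj"))
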
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSerializationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSerializationSuite.scala
new file mode 100644
index 0000000000000000000000000000000000000000..91777375608fd0fa1e22712376f5fedd3f3329b1
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSerializationSuite.scala
@@ -0,0 +1,76 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import scala.reflect.runtime.universe.TypeTag
+
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
+import org.apache.spark.sql.catalyst.expressions.NewInstance
+import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, MapPartitions}
+import org.apache.spark.sql.catalyst.plans.PlanTest
+import org.apache.spark.sql.catalyst.rules.RuleExecutor
+
+case class OtherTuple(_1: Int, _2: Int)
+
+class EliminateSerializationSuite extends PlanTest {
+  private object Optimize extends RuleExecutor[LogicalPlan] {
+    val batches =
+      Batch("Serialization", FixedPoint(100),
+        EliminateSerialization) :: Nil
+  }
+
+  implicit private def productEncoder[T <: Product : TypeTag] = ExpressionEncoder[T]()
+  private val func = identity[Iterator[(Int, Int)]] _
+  private val func2 = identity[Iterator[OtherTuple]] _
+
+  def assertObjectCreations(count: Int, plan: LogicalPlan): Unit = {
+    val newInstances = plan.flatMap(_.expressions.collect {
+      case n: NewInstance => n
+    })
+
+    if (newInstances.size != count) {
+      fail(
+        s"""
+           |Wrong number of object creations in plan: ${newInstances.size} != $count
+           |$plan
+         """.stripMargin)
+    }
+  }
+
+  test("back to back MapPartitions") {
+    val input = LocalRelation('_1.int, '_2.int)
+    val plan =
+      MapPartitions(func,
+        MapPartitions(func, input))
+
+    val optimized = Optimize.execute(plan.analyze)
+    assertObjectCreations(1, optimized)
+  }
+
+  test("back to back with object change") {
+    val input = LocalRelation('_1.int, '_2.int)
+    val plan =
+      MapPartitions(func,
+        MapPartitions(func2, input))
+
+    val optimized = Optimize.execute(plan.analyze)
+    assertObjectCreations(2, optimized)
+  }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 42f01e9359c6488e092cf21eb35e5c540fa836a2..9a9f7d111cf4bc9e0f8ba3acda7b1fa88a4fa12c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -336,12 +336,7 @@ class Dataset[T] private[sql](
   def mapPartitions[U : Encoder](func: Iterator[T] => Iterator[U]): Dataset[U] = {
     new Dataset[U](
       sqlContext,
-      MapPartitions[T, U](
-        func,
-        resolvedTEncoder,
-        encoderFor[U],
-        encoderFor[U].schema.toAttributes,
-        logicalPlan))
+      MapPartitions[T, U](func, logicalPlan))
   }
 
   /**
@@ -434,7 +429,7 @@ class Dataset[T] private[sql](
    */
   def groupBy[K : Encoder](func: T => K): GroupedDataset[K, T] = {
     val inputPlan = logicalPlan
-    val withGroupingKey = AppendColumns(func, resolvedTEncoder, inputPlan)
+    val withGroupingKey = AppendColumns(func, inputPlan)
     val executed = sqlContext.executePlan(withGroupingKey)
 
     new GroupedDataset(
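The groupBy path above shows the division of labor after this patch: AppendColumns computes the grouping key with the user function and appends its serialized form to each row, and the MapGroups built later by GroupedDataset consumes those columns. A hedged usage sketch, assuming testImplicits (or sqlContext.implicits) is in scope:

    val ds = Seq(("a", 1), ("a", 2), ("b", 3)).toDS()
    // AppendColumns adds the key produced by `_._1`; MapGroups then folds each group.
    val summed = ds.groupBy(_._1).mapGroups { (key, values) =>
      Iterator(key -> values.map(_._2).sum)
    }
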
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
index a819ddceb1b1baa90ce5a5e3a146ca014669a2d0..b3f8284364782026eb4abf5b88d8eb0f38ac5c9e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
@@ -115,8 +115,6 @@ class GroupedDataset[K, V] private[sql](
       sqlContext,
       MapGroups(
         f,
-        resolvedKEncoder,
-        resolvedVEncoder,
         groupingAttributes,
         logicalPlan))
   }
@@ -305,13 +303,11 @@ class GroupedDataset[K, V] private[sql](
   def cogroup[U, R : Encoder](
       other: GroupedDataset[K, U])(
       f: (K, Iterator[V], Iterator[U]) => TraversableOnce[R]): Dataset[R] = {
+    implicit val uEncoder = other.unresolvedVEncoder
     new Dataset[R](
       sqlContext,
       CoGroup(
         f,
-        this.resolvedKEncoder,
-        this.resolvedVEncoder,
-        other.resolvedVEncoder,
         this.groupingAttributes,
         other.groupingAttributes,
         this.logicalPlan,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 482130a18d9397a993ba79a935caa908527fa803..910519d0e68144b1a38fbe1c97876ef6c7e05665 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -309,16 +309,15 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
         throw new IllegalStateException(
           "logical distinct operator should have been replaced by aggregate in the optimizer")
 
-      case logical.MapPartitions(f, tEnc, uEnc, output, child) =>
-        execution.MapPartitions(f, tEnc, uEnc, output, planLater(child)) :: Nil
-      case logical.AppendColumns(f, tEnc, uEnc, newCol, child) =>
-        execution.AppendColumns(f, tEnc, uEnc, newCol, planLater(child)) :: Nil
-      case logical.MapGroups(f, kEnc, tEnc, uEnc, grouping, output, child) =>
-        execution.MapGroups(f, kEnc, tEnc, uEnc, grouping, output, planLater(child)) :: Nil
-      case logical.CoGroup(f, kEnc, leftEnc, rightEnc, rEnc, output,
-        leftGroup, rightGroup, left, right) =>
-        execution.CoGroup(f, kEnc, leftEnc, rightEnc, rEnc, output, leftGroup, rightGroup,
-          planLater(left), planLater(right)) :: Nil
+      case logical.MapPartitions(f, in, out, child) =>
+        execution.MapPartitions(f, in, out, planLater(child)) :: Nil
+      case logical.AppendColumns(f, in, out, child) =>
+        execution.AppendColumns(f, in, out, planLater(child)) :: Nil
+      case logical.MapGroups(f, key, in, out, grouping, child) =>
+        execution.MapGroups(f, key, in, out, grouping, planLater(child)) :: Nil
+      case logical.CoGroup(f, keyObj, lObj, rObj, out, lGroup, rGroup, left, right) =>
+        execution.CoGroup(
+          f, keyObj, lObj, rObj, out, lGroup, rGroup, planLater(left), planLater(right)) :: Nil
 
       case logical.Repartition(numPartitions, shuffle, child) =>
         if (shuffle) {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index 95bef683238a7b9110648e6e36fa00049834ee2d..92c9a561312bab04a0c8abdd5cf1d8d3fdb97db0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -21,9 +21,7 @@ import org.apache.spark.{HashPartitioner, SparkEnv}
 import org.apache.spark.rdd.{PartitionwiseSampledRDD, RDD, ShuffledRDD}
 import org.apache.spark.shuffle.sort.SortShuffleManager
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeRowJoiner
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.execution.metric.SQLMetrics
 import org.apache.spark.sql.types.LongType
@@ -329,128 +327,3 @@ case class OutputFaker(output: Seq[Attribute], child: SparkPlan) extends SparkPl
 
   protected override def doExecute(): RDD[InternalRow] = child.execute()
 }
-
-/**
- * Applies the given function to each input row and encodes the result.
- */
-case class MapPartitions[T, U](
-    func: Iterator[T] => Iterator[U],
-    tEncoder: ExpressionEncoder[T],
-    uEncoder: ExpressionEncoder[U],
-    output: Seq[Attribute],
-    child: SparkPlan) extends UnaryNode {
-  override def producedAttributes: AttributeSet = outputSet
-
-  override protected def doExecute(): RDD[InternalRow] = {
-    child.execute().mapPartitionsInternal { iter =>
-      val tBoundEncoder = tEncoder.bind(child.output)
-      func(iter.map(tBoundEncoder.fromRow)).map(uEncoder.toRow)
-    }
-  }
-}
-
-/**
- * Applies the given function to each input row, appending the encoded result at the end of the row.
- */
-case class AppendColumns[T, U](
-    func: T => U,
-    tEncoder: ExpressionEncoder[T],
-    uEncoder: ExpressionEncoder[U],
-    newColumns: Seq[Attribute],
-    child: SparkPlan) extends UnaryNode {
-  override def producedAttributes: AttributeSet = AttributeSet(newColumns)
-
-  override def output: Seq[Attribute] = child.output ++ newColumns
-
-  override protected def doExecute(): RDD[InternalRow] = {
-    child.execute().mapPartitionsInternal { iter =>
-      val tBoundEncoder = tEncoder.bind(child.output)
-      val combiner = GenerateUnsafeRowJoiner.create(tEncoder.schema, uEncoder.schema)
-      iter.map { row =>
-        val newColumns = uEncoder.toRow(func(tBoundEncoder.fromRow(row)))
-        combiner.join(row.asInstanceOf[UnsafeRow], newColumns.asInstanceOf[UnsafeRow]): InternalRow
-      }
-    }
-  }
-}
-
-/**
- * Groups the input rows together and calls the function with each group and an iterator containing
- * all elements in the group. The result of this function is encoded and flattened before
- * being output.
- */
-case class MapGroups[K, T, U](
-    func: (K, Iterator[T]) => TraversableOnce[U],
-    kEncoder: ExpressionEncoder[K],
-    tEncoder: ExpressionEncoder[T],
-    uEncoder: ExpressionEncoder[U],
-    groupingAttributes: Seq[Attribute],
-    output: Seq[Attribute],
-    child: SparkPlan) extends UnaryNode {
-  override def producedAttributes: AttributeSet = outputSet
-
-  override def requiredChildDistribution: Seq[Distribution] =
-    ClusteredDistribution(groupingAttributes) :: Nil
-
-  override def requiredChildOrdering: Seq[Seq[SortOrder]] =
-    Seq(groupingAttributes.map(SortOrder(_, Ascending)))
-
-  override protected def doExecute(): RDD[InternalRow] = {
-    child.execute().mapPartitionsInternal { iter =>
-      val grouped = GroupedIterator(iter, groupingAttributes, child.output)
-      val groupKeyEncoder = kEncoder.bind(groupingAttributes)
-      val groupDataEncoder = tEncoder.bind(child.output)
-
-      grouped.flatMap { case (key, rowIter) =>
-        val result = func(
-          groupKeyEncoder.fromRow(key),
-          rowIter.map(groupDataEncoder.fromRow))
-        result.map(uEncoder.toRow)
-      }
-    }
-  }
-}
-
-/**
- * Co-groups the data from left and right children, and calls the function with each group and 2
- * iterators containing all elements in the group from left and right side.
- * The result of this function is encoded and flattened before being output.
- */
-case class CoGroup[Key, Left, Right, Result](
-    func: (Key, Iterator[Left], Iterator[Right]) => TraversableOnce[Result],
-    keyEnc: ExpressionEncoder[Key],
-    leftEnc: ExpressionEncoder[Left],
-    rightEnc: ExpressionEncoder[Right],
-    resultEnc: ExpressionEncoder[Result],
-    output: Seq[Attribute],
-    leftGroup: Seq[Attribute],
-    rightGroup: Seq[Attribute],
-    left: SparkPlan,
-    right: SparkPlan) extends BinaryNode {
-  override def producedAttributes: AttributeSet = outputSet
-
-  override def requiredChildDistribution: Seq[Distribution] =
-    ClusteredDistribution(leftGroup) :: ClusteredDistribution(rightGroup) :: Nil
-
-  override def requiredChildOrdering: Seq[Seq[SortOrder]] =
-    leftGroup.map(SortOrder(_, Ascending)) :: rightGroup.map(SortOrder(_, Ascending)) :: Nil
-
-  override protected def doExecute(): RDD[InternalRow] = {
-    left.execute().zipPartitions(right.execute()) { (leftData, rightData) =>
-      val leftGrouped = GroupedIterator(leftData, leftGroup, left.output)
-      val rightGrouped = GroupedIterator(rightData, rightGroup, right.output)
-      val boundKeyEnc = keyEnc.bind(leftGroup)
-      val boundLeftEnc = leftEnc.bind(left.output)
-      val boundRightEnc = rightEnc.bind(right.output)
-
-      new CoGroupedIterator(leftGrouped, rightGrouped, leftGroup).flatMap {
-        case (key, leftResult, rightResult) =>
-          val result = func(
-            boundKeyEnc.fromRow(key),
-            leftResult.map(boundLeftEnc.fromRow),
-            rightResult.map(boundRightEnc.fromRow))
-          result.map(resultEnc.toRow)
-      }
-    }
-  }
-}
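The replacement physical operators in the new execution/objects.scala below all share one shape: decode incoming rows into objects with one projection, run the user function, and re-encode the results with another. A self-contained analog of that bridge, with type parameters standing in for InternalRow and the object types (names are illustrative, not from the patch):

    // `decode`/`encode` stand in for the projections built by generateToObject and
    // generateToRow; `f` is the user function, as in MapPartitions.doExecute.
    def bridge[Row, T, U](
        decode: Row => T,
        encode: U => Row,
        f: Iterator[T] => Iterator[U]): Iterator[Row] => Iterator[Row] =
      rows => f(rows.map(decode)).map(encode)

    // Example: rows modeled as strings, objects as ints.
    val run = bridge[String, Int, Int](_.toInt, _.toString, _.map(_ + 1))
    assert(run(Iterator("1", "2")).toSeq == Seq("2", "3"))
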
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
new file mode 100644
index 0000000000000000000000000000000000000000..2acca1743cbb905dc7e637b0f0d18f1a912dfdbb
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
@@ -0,0 +1,182 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection, GenerateUnsafeProjection, GenerateUnsafeRowJoiner}
+import org.apache.spark.sql.catalyst.plans.physical._
+import org.apache.spark.sql.types.ObjectType
+
+/**
+ * Helper functions for physical operators that work with user defined objects.
+ */
+trait ObjectOperator extends SparkPlan {
+  def generateToObject(objExpr: Expression, inputSchema: Seq[Attribute]): InternalRow => Any = {
+    val objectProjection = GenerateSafeProjection.generate(objExpr :: Nil, inputSchema)
+    (i: InternalRow) => objectProjection(i).get(0, objExpr.dataType)
+  }
+
+  def generateToRow(serializer: Seq[Expression]): Any => InternalRow = {
+    val outputProjection = if (serializer.head.dataType.isInstanceOf[ObjectType]) {
+      GenerateSafeProjection.generate(serializer)
+    } else {
+      GenerateUnsafeProjection.generate(serializer)
+    }
+    val inputType = serializer.head.collect { case b: BoundReference => b.dataType }.head
+    val outputRow = new SpecificMutableRow(inputType :: Nil)
+    (o: Any) => {
+      outputRow(0) = o
+      outputProjection(outputRow)
+    }
+  }
+}
+
+/**
+ * Applies the given function to the iterator of rows in each input partition and encodes the
+ * results.
+ */
+case class MapPartitions(
+    func: Iterator[Any] => Iterator[Any],
+    input: Expression,
+    serializer: Seq[NamedExpression],
+    child: SparkPlan) extends UnaryNode with ObjectOperator {
+  override def output: Seq[Attribute] = serializer.map(_.toAttribute)
+
+  override protected def doExecute(): RDD[InternalRow] = {
+    child.execute().mapPartitionsInternal { iter =>
+      val getObject = generateToObject(input, child.output)
+      val outputObject = generateToRow(serializer)
+      func(iter.map(getObject)).map(outputObject)
+    }
+  }
+}
+
+/**
+ * Applies the given function to each input row, appending the encoded result at the end of the
+ * row.
+ */
+case class AppendColumns(
+    func: Any => Any,
+    input: Expression,
+    serializer: Seq[NamedExpression],
+    child: SparkPlan) extends UnaryNode with ObjectOperator {
+
+  override def output: Seq[Attribute] = child.output ++ serializer.map(_.toAttribute)
+
+  private def newColumnSchema = serializer.map(_.toAttribute).toStructType
+
+  override protected def doExecute(): RDD[InternalRow] = {
+    child.execute().mapPartitionsInternal { iter =>
+      val getObject = generateToObject(input, child.output)
+      val combiner = GenerateUnsafeRowJoiner.create(child.schema, newColumnSchema)
+      val outputObject = generateToRow(serializer)
+
+      iter.map { row =>
+        val newColumns = outputObject(func(getObject(row)))
+
+        // This operates on the assumption that we always serialize the result...
+        combiner.join(row.asInstanceOf[UnsafeRow], newColumns.asInstanceOf[UnsafeRow]): InternalRow
+      }
+    }
+  }
+}
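MapGroups below leans on its required child ordering: input arrives sorted by the grouping key, so GroupedIterator can stream one group at a time without building a hash table. A standalone sketch of sort-based grouping over a plain iterator (illustrative; GroupedIterator applies the same idea to rows, and each group must be fully consumed before advancing):

    def groupSorted[K, V](it: Iterator[(K, V)]): Iterator[(K, Iterator[V])] = {
      val in = it.buffered
      new Iterator[(K, Iterator[V])] {
        def hasNext: Boolean = in.hasNext
        def next(): (K, Iterator[V]) = {
          val key = in.head._1
          val group = new Iterator[V] {
            def hasNext: Boolean = in.hasNext && in.head._1 == key
            def next(): V = in.next()._2
          }
          (key, group)
        }
      }
    }

    // groupSorted(Iterator("a" -> 1, "a" -> 2, "b" -> 3)) yields the "a" group and then
    // the "b" group, provided each group's iterator is drained before the next call.
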
+/**
+ * Groups the input rows together and calls the function with each group and an iterator containing
+ * all elements in the group. The result of this function is encoded and flattened before
+ * being output.
+ */
+case class MapGroups(
+    func: (Any, Iterator[Any]) => TraversableOnce[Any],
+    keyObject: Expression,
+    input: Expression,
+    serializer: Seq[NamedExpression],
+    groupingAttributes: Seq[Attribute],
+    child: SparkPlan) extends UnaryNode with ObjectOperator {
+
+  override def output: Seq[Attribute] = serializer.map(_.toAttribute)
+
+  override def requiredChildDistribution: Seq[Distribution] =
+    ClusteredDistribution(groupingAttributes) :: Nil
+
+  override def requiredChildOrdering: Seq[Seq[SortOrder]] =
+    Seq(groupingAttributes.map(SortOrder(_, Ascending)))
+
+  override protected def doExecute(): RDD[InternalRow] = {
+    child.execute().mapPartitionsInternal { iter =>
+      val grouped = GroupedIterator(iter, groupingAttributes, child.output)
+
+      val getKey = generateToObject(keyObject, groupingAttributes)
+      val getValue = generateToObject(input, child.output)
+      val outputObject = generateToRow(serializer)
+
+      grouped.flatMap { case (key, rowIter) =>
+        val result = func(
+          getKey(key),
+          rowIter.map(getValue))
+        result.map(outputObject)
+      }
+    }
+  }
+}
+
+/**
+ * Co-groups the data from left and right children, and calls the function with each group and 2
+ * iterators containing all elements in the group from left and right side.
+ * The result of this function is encoded and flattened before being output.
+ */
+case class CoGroup(
+    func: (Any, Iterator[Any], Iterator[Any]) => TraversableOnce[Any],
+    keyObject: Expression,
+    leftObject: Expression,
+    rightObject: Expression,
+    serializer: Seq[NamedExpression],
+    leftGroup: Seq[Attribute],
+    rightGroup: Seq[Attribute],
+    left: SparkPlan,
+    right: SparkPlan) extends BinaryNode with ObjectOperator {
+
+  override def output: Seq[Attribute] = serializer.map(_.toAttribute)
+
+  override def requiredChildDistribution: Seq[Distribution] =
+    ClusteredDistribution(leftGroup) :: ClusteredDistribution(rightGroup) :: Nil
+
+  override def requiredChildOrdering: Seq[Seq[SortOrder]] =
+    leftGroup.map(SortOrder(_, Ascending)) :: rightGroup.map(SortOrder(_, Ascending)) :: Nil
+
+  override protected def doExecute(): RDD[InternalRow] = {
+    left.execute().zipPartitions(right.execute()) { (leftData, rightData) =>
+      val leftGrouped = GroupedIterator(leftData, leftGroup, left.output)
+      val rightGrouped = GroupedIterator(rightData, rightGroup, right.output)
+
+      val getKey = generateToObject(keyObject, leftGroup)
+      val getLeft = generateToObject(leftObject, left.output)
+      val getRight = generateToObject(rightObject, right.output)
+      val outputObject = generateToRow(serializer)
+
+      new CoGroupedIterator(leftGrouped, rightGrouped, leftGroup).flatMap {
+        case (key, leftResult, rightResult) =>
+          val result = func(
+            getKey(key),
+            leftResult.map(getLeft),
+            rightResult.map(getRight))
+          result.map(outputObject)
+      }
+    }
+  }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index d7b86e381108ed7b7444860c589f2d9ddddda493..b69bb21db532b14082e19a0efc26bda92077ab07 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -26,6 +26,8 @@ import org.apache.spark.sql.functions._
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
 
+case class OtherTuple(_1: String, _2: Int)
+
 class DatasetSuite extends QueryTest with SharedSQLContext {
   import testImplicits._
 
@@ -111,6 +113,16 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
       ("a", 2), ("b", 3), ("c", 4))
   }
 
+  test("map with type change") {
+    val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDS()
+
+    checkAnswer(
+      ds.map(identity[(String, Int)])
+        .as[OtherTuple]
+        .map(identity[OtherTuple]),
+      OtherTuple("a", 1), OtherTuple("b", 2), OtherTuple("c", 3))
+  }
+
   test("map and group by with class data") {
     // We inject a group by here to make sure this test case is future proof
     // when we implement better pipelining and local execution mode.
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
index fac26bd0c0269c17abfabf264f7585504332e64e..ce12f788b786c0cae2e09aa0ef7c24d4f449d75e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
@@ -192,10 +192,10 @@ abstract class QueryTest extends PlanTest {
     val logicalPlan = df.queryExecution.analyzed
     // bypass some cases that we can't handle currently.
     logicalPlan.transform {
-      case _: MapPartitions[_, _] => return
-      case _: MapGroups[_, _, _] => return
-      case _: AppendColumns[_, _] => return
-      case _: CoGroup[_, _, _, _] => return
+      case _: MapPartitions => return
+      case _: MapGroups => return
+      case _: AppendColumns => return
+      case _: CoGroup => return
       case _: LogicalRelation => return
     }.transformAllExpressions {
       case a: ImperativeAggregate => return