diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 27c96f41221adaca13b927887f6b0c9e29fa5062..713c6b547d9b7881819f300211adbc99551da619 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -411,9 +411,9 @@ trait ScalaReflection { } /** Returns expressions for extracting all the fields from the given type. */ - def extractorsFor[T : TypeTag](inputObject: Expression): Seq[Expression] = { + def extractorsFor[T : TypeTag](inputObject: Expression): CreateNamedStruct = { ScalaReflectionLock.synchronized { - extractorFor(inputObject, typeTag[T].tpe).asInstanceOf[CreateStruct].children + extractorFor(inputObject, typeTag[T].tpe).asInstanceOf[CreateNamedStruct] } } @@ -497,11 +497,11 @@ trait ScalaReflection { } } - CreateStruct(params.head.map { p => + CreateNamedStruct(params.head.flatMap { p => val fieldName = p.name.toString val fieldType = p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs) val fieldValue = Invoke(inputObject, fieldName, dataTypeFor(fieldType)) - extractorFor(fieldValue, fieldType) + expressions.Literal(fieldName) :: extractorFor(fieldValue, fieldType) :: Nil }) case t if t <:< localTypeOf[Array[_]] => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ClassEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ClassEncoder.scala index 54096f18cbea1d332c352c1298f050ce4c92b613..b484b8fde63692b7c73cfedb4a69144d8b4dea9d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ClassEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ClassEncoder.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.encoders import scala.reflect.ClassTag import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, SimpleAnalyzer} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection, GenerateUnsafeProjection} import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project} @@ -41,9 +41,11 @@ case class ClassEncoder[T]( clsTag: ClassTag[T]) extends Encoder[T] { - private val extractProjection = GenerateUnsafeProjection.generate(extractExpressions) + @transient + private lazy val extractProjection = GenerateUnsafeProjection.generate(extractExpressions) private val inputRow = new GenericMutableRow(1) + @transient private lazy val constructProjection = GenerateSafeProjection.generate(constructExpression :: Nil) private val dataType = ObjectType(clsTag.runtimeClass) @@ -64,4 +66,36 @@ case class ClassEncoder[T]( copy(constructExpression = boundExpression) } + + override def rebind(oldSchema: Seq[Attribute], newSchema: Seq[Attribute]): ClassEncoder[T] = { + val positionToAttribute = AttributeMap.toIndex(oldSchema) + val attributeToNewPosition = AttributeMap.byIndex(newSchema) + copy(constructExpression = constructExpression transform { + case r: BoundReference => + r.copy(ordinal = attributeToNewPosition(positionToAttribute(r.ordinal))) + }) + } + + override def bindOrdinals(schema: Seq[Attribute]): ClassEncoder[T] = { + var remaining = schema + copy(constructExpression = constructExpression transform { + case u: 
UnresolvedAttribute => + val pos = remaining.head + remaining = remaining.drop(1) + pos + }) + } + + protected val attrs = extractExpressions.map(_.collect { + case a: Attribute => s"#${a.exprId}" + case b: BoundReference => s"[${b.ordinal}]" + }.headOption.getOrElse("")) + + + protected val schemaString = + schema + .zip(attrs) + .map { case(f, a) => s"${f.name}$a: ${f.dataType.simpleString}"}.mkString(", ") + + override def toString: String = s"class[$schemaString]" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala index bdb1c0959da871a5f27249da57ce8c1ae622bcc4..efb872ddb81e52ba1ca35222aefb5af36979739c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala @@ -31,6 +31,7 @@ import org.apache.spark.sql.types.StructType * and reuse internal buffers to improve performance. */ trait Encoder[T] { + /** Returns the schema of encoding this type of object as a Row. */ def schema: StructType @@ -46,13 +47,27 @@ trait Encoder[T] { /** * Returns an object of type `T`, extracting the required values from the provided row. Note that - * you must bind the encoder to a specific schema before you can call this function. + * you must `bind` an encoder to a specific schema before you can call this function. */ def fromRow(row: InternalRow): T /** * Returns a new copy of this encoder, where the expressions used by `fromRow` are bound to the - * given schema + * given schema. */ def bind(schema: Seq[Attribute]): Encoder[T] + + /** + * Binds this encoder to the given schema positionally. In this binding, the first reference to + * any input is mapped to `schema(0)`, and so on for each input that is encountered. + */ + def bindOrdinals(schema: Seq[Attribute]): Encoder[T] + + /** + * Given an encoder that has already been bound to a given schema, returns a new encoder that + * where the positions are mapped from `oldSchema` to `newSchema`. This can be used, for example, + * when you are trying to use an encoder on grouping keys that were orriginally part of a larger + * row, but now you have projected out only the key expressions. + */ + def rebind(oldSchema: Seq[Attribute], newSchema: Seq[Attribute]): Encoder[T] } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala index 4f7ce455ada99d46cb3c44a225125e03e87b0844..34f5e6c030f5889ecf32ebb433b5b0156aeca208 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala @@ -31,15 +31,17 @@ import org.apache.spark.sql.types.{ObjectType, StructType} object ProductEncoder { def apply[T <: Product : TypeTag]: ClassEncoder[T] = { // We convert the not-serializable TypeTag into StructType and ClassTag. 
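+      // The object schema can now be obtained from the named-struct extractor's dataType below,
+      // so a separate ScalaReflection.schemaFor call is no longer needed.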
- val schema = ScalaReflection.schemaFor[T].dataType.asInstanceOf[StructType] val mirror = typeTag[T].mirror val cls = mirror.runtimeClass(typeTag[T].tpe) val inputObject = BoundReference(0, ObjectType(cls), nullable = true) - val extractExpressions = ScalaReflection.extractorsFor[T](inputObject) + val extractExpression = ScalaReflection.extractorsFor[T](inputObject) val constructExpression = ScalaReflection.constructorFor[T] - new ClassEncoder[T](schema, extractExpressions, constructExpression, ClassTag[T](cls)) - } - + new ClassEncoder[T]( + extractExpression.dataType, + extractExpression.flatten, + constructExpression, + ClassTag[T](cls)) + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/primitiveTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/primitiveTypes.scala new file mode 100644 index 0000000000000000000000000000000000000000..a93f2d7c6115dad6363617123275827163b3bfbe --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/primitiveTypes.scala @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.encoders + +import scala.reflect.ClassTag + +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.sql.types._ + +/** An encoder for primitive Long types. */ +case class LongEncoder(fieldName: String = "value", ordinal: Int = 0) extends Encoder[Long] { + private val row = UnsafeRow.createFromByteArray(64, 1) + + override def clsTag: ClassTag[Long] = ClassTag.Long + override def schema: StructType = + StructType(StructField(fieldName, LongType) :: Nil) + + override def fromRow(row: InternalRow): Long = row.getLong(ordinal) + + override def toRow(t: Long): InternalRow = { + row.setLong(ordinal, t) + row + } + + override def bindOrdinals(schema: Seq[Attribute]): Encoder[Long] = this + override def bind(schema: Seq[Attribute]): Encoder[Long] = this + override def rebind(oldSchema: Seq[Attribute], newSchema: Seq[Attribute]): Encoder[Long] = this +} + +/** An encoder for primitive Integer types. 
*/ +case class IntEncoder(fieldName: String = "value", ordinal: Int = 0) extends Encoder[Int] { + private val row = UnsafeRow.createFromByteArray(64, 1) + + override def clsTag: ClassTag[Int] = ClassTag.Int + override def schema: StructType = + StructType(StructField(fieldName, IntegerType) :: Nil) + + override def fromRow(row: InternalRow): Int = row.getInt(ordinal) + + override def toRow(t: Int): InternalRow = { + row.setInt(ordinal, t) + row + } + + override def bindOrdinals(schema: Seq[Attribute]): Encoder[Int] = this + override def bind(schema: Seq[Attribute]): Encoder[Int] = this + override def rebind(oldSchema: Seq[Attribute], newSchema: Seq[Attribute]): Encoder[Int] = this +} + +/** An encoder for String types. */ +case class StringEncoder( + fieldName: String = "value", + ordinal: Int = 0) extends Encoder[String] { + + val record = new SpecificMutableRow(StringType :: Nil) + + @transient + lazy val projection = + GenerateUnsafeProjection.generate(BoundReference(0, StringType, true) :: Nil) + + override def schema: StructType = + StructType( + StructField("value", StringType, nullable = false) :: Nil) + + override def clsTag: ClassTag[String] = scala.reflect.classTag[String] + + + override final def fromRow(row: InternalRow): String = { + row.getString(ordinal) + } + + override final def toRow(value: String): InternalRow = { + val utf8String = UTF8String.fromString(value) + record(0) = utf8String + // TODO: this is a bit of a hack to produce UnsafeRows + projection(record) + } + + override def bindOrdinals(schema: Seq[Attribute]): Encoder[String] = this + override def bind(schema: Seq[Attribute]): Encoder[String] = this + override def rebind(oldSchema: Seq[Attribute], newSchema: Seq[Attribute]): Encoder[String] = this +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/tuples.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/tuples.scala new file mode 100644 index 0000000000000000000000000000000000000000..a48eeda7d2e6f034e6a5e006d6ec200bff31e5ff --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/tuples.scala @@ -0,0 +1,173 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.encoders + + +import scala.reflect.ClassTag + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.types.{StructField, StructType} + +// Most of this file is codegen. +// scalastyle:off + +/** + * A set of composite encoders that take sub encoders and map each of their objects to a + * Scala tuple. Note that currently the implementation is fairly limited and only supports going + * from an internal row to a tuple. 
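+ *
+ * For example (an illustrative sketch; `row` is assumed to be an InternalRow whose first two
+ * columns hold an int and a string):
+ * {{{
+ *   // Each sub-encoder is pointed at its ordinal in the underlying row.
+ *   val enc = new Tuple2Encoder(new IntEncoder(ordinal = 0), new StringEncoder(ordinal = 1))
+ *   val pair: (Int, String) = enc.fromRow(row)
+ * }}}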
+ */ +object TupleEncoder { + + /** Code generator for composite tuple encoders. */ + def main(args: Array[String]): Unit = { + (2 to 5).foreach { i => + val types = (1 to i).map(t => s"T$t").mkString(", ") + val tupleType = s"($types)" + val args = (1 to i).map(t => s"e$t: Encoder[T$t]").mkString(", ") + val fields = (1 to i).map(t => s"""StructField("_$t", e$t.schema)""").mkString(", ") + val fromRow = (1 to i).map(t => s"e$t.fromRow(row)").mkString(", ") + + println( + s""" + |class Tuple${i}Encoder[$types]($args) extends Encoder[$tupleType] { + | val schema = StructType(Array($fields)) + | + | def clsTag: ClassTag[$tupleType] = scala.reflect.classTag[$tupleType] + | + | def fromRow(row: InternalRow): $tupleType = { + | ($fromRow) + | } + | + | override def toRow(t: $tupleType): InternalRow = + | throw new UnsupportedOperationException("Tuple Encoders only support fromRow.") + | + | override def bind(schema: Seq[Attribute]): Encoder[$tupleType] = { + | this + | } + | + | override def rebind(oldSchema: Seq[Attribute], newSchema: Seq[Attribute]): Encoder[$tupleType] = + | throw new UnsupportedOperationException("Tuple Encoders only support bind.") + | + | + | override def bindOrdinals(schema: Seq[Attribute]): Encoder[$tupleType] = + | throw new UnsupportedOperationException("Tuple Encoders only support bind.") + |} + """.stripMargin) + } + } +} + +class Tuple2Encoder[T1, T2](e1: Encoder[T1], e2: Encoder[T2]) extends Encoder[(T1, T2)] { + val schema = StructType(Array(StructField("_1", e1.schema), StructField("_2", e2.schema))) + + def clsTag: ClassTag[(T1, T2)] = scala.reflect.classTag[(T1, T2)] + + def fromRow(row: InternalRow): (T1, T2) = { + (e1.fromRow(row), e2.fromRow(row)) + } + + override def toRow(t: (T1, T2)): InternalRow = + throw new UnsupportedOperationException("Tuple Encoders only support fromRow.") + + override def bind(schema: Seq[Attribute]): Encoder[(T1, T2)] = { + this + } + + override def rebind(oldSchema: Seq[Attribute], newSchema: Seq[Attribute]): Encoder[(T1, T2)] = + throw new UnsupportedOperationException("Tuple Encoders only support bind.") + + + override def bindOrdinals(schema: Seq[Attribute]): Encoder[(T1, T2)] = + throw new UnsupportedOperationException("Tuple Encoders only support bind.") +} + + +class Tuple3Encoder[T1, T2, T3](e1: Encoder[T1], e2: Encoder[T2], e3: Encoder[T3]) extends Encoder[(T1, T2, T3)] { + val schema = StructType(Array(StructField("_1", e1.schema), StructField("_2", e2.schema), StructField("_3", e3.schema))) + + def clsTag: ClassTag[(T1, T2, T3)] = scala.reflect.classTag[(T1, T2, T3)] + + def fromRow(row: InternalRow): (T1, T2, T3) = { + (e1.fromRow(row), e2.fromRow(row), e3.fromRow(row)) + } + + override def toRow(t: (T1, T2, T3)): InternalRow = + throw new UnsupportedOperationException("Tuple Encoders only support fromRow.") + + override def bind(schema: Seq[Attribute]): Encoder[(T1, T2, T3)] = { + this + } + + override def rebind(oldSchema: Seq[Attribute], newSchema: Seq[Attribute]): Encoder[(T1, T2, T3)] = + throw new UnsupportedOperationException("Tuple Encoders only support bind.") + + + override def bindOrdinals(schema: Seq[Attribute]): Encoder[(T1, T2, T3)] = + throw new UnsupportedOperationException("Tuple Encoders only support bind.") +} + + +class Tuple4Encoder[T1, T2, T3, T4](e1: Encoder[T1], e2: Encoder[T2], e3: Encoder[T3], e4: Encoder[T4]) extends Encoder[(T1, T2, T3, T4)] { + val schema = StructType(Array(StructField("_1", e1.schema), StructField("_2", e2.schema), StructField("_3", e3.schema), StructField("_4", 
e4.schema))) + + def clsTag: ClassTag[(T1, T2, T3, T4)] = scala.reflect.classTag[(T1, T2, T3, T4)] + + def fromRow(row: InternalRow): (T1, T2, T3, T4) = { + (e1.fromRow(row), e2.fromRow(row), e3.fromRow(row), e4.fromRow(row)) + } + + override def toRow(t: (T1, T2, T3, T4)): InternalRow = + throw new UnsupportedOperationException("Tuple Encoders only support fromRow.") + + override def bind(schema: Seq[Attribute]): Encoder[(T1, T2, T3, T4)] = { + this + } + + override def rebind(oldSchema: Seq[Attribute], newSchema: Seq[Attribute]): Encoder[(T1, T2, T3, T4)] = + throw new UnsupportedOperationException("Tuple Encoders only support bind.") + + + override def bindOrdinals(schema: Seq[Attribute]): Encoder[(T1, T2, T3, T4)] = + throw new UnsupportedOperationException("Tuple Encoders only support bind.") +} + + +class Tuple5Encoder[T1, T2, T3, T4, T5](e1: Encoder[T1], e2: Encoder[T2], e3: Encoder[T3], e4: Encoder[T4], e5: Encoder[T5]) extends Encoder[(T1, T2, T3, T4, T5)] { + val schema = StructType(Array(StructField("_1", e1.schema), StructField("_2", e2.schema), StructField("_3", e3.schema), StructField("_4", e4.schema), StructField("_5", e5.schema))) + + def clsTag: ClassTag[(T1, T2, T3, T4, T5)] = scala.reflect.classTag[(T1, T2, T3, T4, T5)] + + def fromRow(row: InternalRow): (T1, T2, T3, T4, T5) = { + (e1.fromRow(row), e2.fromRow(row), e3.fromRow(row), e4.fromRow(row), e5.fromRow(row)) + } + + override def toRow(t: (T1, T2, T3, T4, T5)): InternalRow = + throw new UnsupportedOperationException("Tuple Encoders only support fromRow.") + + override def bind(schema: Seq[Attribute]): Encoder[(T1, T2, T3, T4, T5)] = { + this + } + + override def rebind(oldSchema: Seq[Attribute], newSchema: Seq[Attribute]): Encoder[(T1, T2, T3, T4, T5)] = + throw new UnsupportedOperationException("Tuple Encoders only support bind.") + + + override def bindOrdinals(schema: Seq[Attribute]): Encoder[(T1, T2, T3, T4, T5)] = + throw new UnsupportedOperationException("Tuple Encoders only support bind.") +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala index 96a11e352ec509176d66334d5100bd0a12095525..ef3cc554b79c07217342899a7d9a0f6dc87a1b82 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala @@ -26,6 +26,13 @@ object AttributeMap { def apply[A](kvs: Seq[(Attribute, A)]): AttributeMap[A] = { new AttributeMap(kvs.map(kv => (kv._1.exprId, kv)).toMap) } + + /** Given a schema, constructs an [[AttributeMap]] from [[Attribute]] to ordinal */ + def byIndex(schema: Seq[Attribute]): AttributeMap[Int] = apply(schema.zipWithIndex) + + /** Given a schema, constructs a map from ordinal to Attribute. 
*/ + def toIndex(schema: Seq[Attribute]): Map[Int, Attribute] = + schema.zipWithIndex.map { case (a, i) => i -> a }.toMap } class AttributeMap[A](baseMap: Map[ExprId, (Attribute, A)]) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala index 5345696570b41295fdec9e08eb8bfe729109641f..383153557420536e6b8cd6bdb08f78e1530a4c49 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala @@ -31,6 +31,10 @@ protected class AttributeEquals(val a: Attribute) { } object AttributeSet { + /** Returns an empty [[AttributeSet]]. */ + val empty = apply(Iterable.empty) + + /** Constructs a new [[AttributeSet]] that contains a single [[Attribute]]. */ def apply(a: Attribute): AttributeSet = new AttributeSet(Set(new AttributeEquals(a))) /** Constructs a new [[AttributeSet]] given a sequence of [[Expression Expressions]]. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index a5f02e2463aed36c29c9d0a5c0d35e6c2e9f2eb5..059e45bd684edd86975203aa2f3abe03cc577f35 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -125,6 +125,14 @@ case class CreateStruct(children: Seq[Expression]) extends Expression { */ case class CreateNamedStruct(children: Seq[Expression]) extends Expression { + /** + * Returns Aliased [[Expressions]] that could be used to construct a flattened version of this + * StructType. + */ + def flatten: Seq[NamedExpression] = valExprs.zip(names).map { + case (v, n) => Alias(v, n.toString)() + } + private lazy val (nameExprs, valExprs) = children.grouped(2).map { case Seq(name, value) => (name, value) }.toList.unzip diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala index 30b7f8d3766a5f1c22c53f3ed8ef0939f208ff27..f1fa13daa77ebf0a52f7c9e4c6001db16d69215c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.types.{StructField, StructType} /** * A set of classes that can be used to represent trees of relational expressions. A key goal of @@ -80,4 +81,15 @@ package object expressions { /** Uses the given row to store the output of the projection. */ def target(row: MutableRow): MutableProjection } + + + /** + * Helper functions for working with `Seq[Attribute]`. + */ + implicit class AttributeSeq(attrs: Seq[Attribute]) { + /** Creates a StructType with a schema matching this `Seq[Attribute]`. 
*/ + def toStructType: StructType = { + StructType(attrs.map(a => StructField(a.name, a.dataType, a.nullable))) + } + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index ae9482c10f126be9450b75e822f5a6139380af1b..21a55a53718410d5d70d6fbbf90b87034628a794 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.plans.logical +import org.apache.spark.sql.catalyst.encoders.Encoder import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.Utils import org.apache.spark.sql.catalyst.plans._ @@ -417,7 +418,7 @@ case class Distinct(child: LogicalPlan) extends UnaryNode { } /** - * Return a new RDD that has exactly `numPartitions` partitions. Differs from + * Returns a new RDD that has exactly `numPartitions` partitions. Differs from * [[RepartitionByExpression]] as this method is called directly by DataFrame's, because the user * asked for `coalesce` or `repartition`. [[RepartitionByExpression]] is used when the consumer * of the output requires some specific ordering or distribution of the data. @@ -443,3 +444,72 @@ case object OneRowRelation extends LeafNode { override def statistics: Statistics = Statistics(sizeInBytes = 1) } +/** + * A relation produced by applying `func` to each partition of the `child`. tEncoder/uEncoder are + * used respectively to decode/encode from the JVM object representation expected by `func.` + */ +case class MapPartitions[T, U]( + func: Iterator[T] => Iterator[U], + tEncoder: Encoder[T], + uEncoder: Encoder[U], + output: Seq[Attribute], + child: LogicalPlan) extends UnaryNode { + override def missingInput: AttributeSet = AttributeSet.empty +} + +/** Factory for constructing new `AppendColumn` nodes. */ +object AppendColumn { + def apply[T : Encoder, U : Encoder](func: T => U, child: LogicalPlan): AppendColumn[T, U] = { + val attrs = implicitly[Encoder[U]].schema.toAttributes + new AppendColumn[T, U](func, implicitly[Encoder[T]], implicitly[Encoder[U]], attrs, child) + } +} + +/** + * A relation produced by applying `func` to each partition of the `child`, concatenating the + * resulting columns at the end of the input row. tEncoder/uEncoder are used respectively to + * decode/encode from the JVM object representation expected by `func.` + */ +case class AppendColumn[T, U]( + func: T => U, + tEncoder: Encoder[T], + uEncoder: Encoder[U], + newColumns: Seq[Attribute], + child: LogicalPlan) extends UnaryNode { + override def output: Seq[Attribute] = child.output ++ newColumns + override def missingInput: AttributeSet = super.missingInput -- newColumns +} + +/** Factory for constructing new `MapGroups` nodes. */ +object MapGroups { + def apply[K : Encoder, T : Encoder, U : Encoder]( + func: (K, Iterator[T]) => Iterator[U], + groupingAttributes: Seq[Attribute], + child: LogicalPlan): MapGroups[K, T, U] = { + new MapGroups( + func, + implicitly[Encoder[K]], + implicitly[Encoder[T]], + implicitly[Encoder[U]], + groupingAttributes, + implicitly[Encoder[U]].schema.toAttributes, + child) + } +} + +/** + * Applies func to each unique group in `child`, based on the evaluation of `groupingAttributes`. 
+ * Func is invoked with an object representation of the grouping key an iterator containing the + * object representation of all the rows with that key. + */ +case class MapGroups[K, T, U]( + func: (K, Iterator[T]) => Iterator[U], + kEncoder: Encoder[K], + tEncoder: Encoder[T], + uEncoder: Encoder[U], + groupingAttributes: Seq[Attribute], + output: Seq[Attribute], + child: LogicalPlan) extends UnaryNode { + override def missingInput: AttributeSet = AttributeSet.empty +} + diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/PrimitiveEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/PrimitiveEncoderSuite.scala new file mode 100644 index 0000000000000000000000000000000000000000..52f8383faca92fc43eeb0d3e1bafadbcc85a0702 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/PrimitiveEncoderSuite.scala @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.encoders + +import org.apache.spark.SparkFunSuite + +class PrimitiveEncoderSuite extends SparkFunSuite { + test("long encoder") { + val enc = new LongEncoder() + val row = enc.toRow(10) + assert(row.getLong(0) == 10) + assert(enc.fromRow(row) == 10) + } + + test("int encoder") { + val enc = new IntEncoder() + val row = enc.toRow(10) + assert(row.getInt(0) == 10) + assert(enc.fromRow(row) == 10) + } + + test("string encoder") { + val enc = new StringEncoder() + val row = enc.toRow("test") + assert(row.getString(0) == "test") + assert(enc.fromRow(row) == "test") + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala index 02e43ddb35478c3eb895e500b74a0eff1e373346..7735acbcbad415f0a20d4a50ab426e2fdac8b8a2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala @@ -248,12 +248,16 @@ class ProductEncoderSuite extends SparkFunSuite { val types = convertedBack.productIterator.filter(_ != null).map(_.getClass.getName).mkString(",") - val encodedData = convertedData.toSeq(encoder.schema).zip(encoder.schema).map { - case (a: ArrayData, StructField(_, at: ArrayType, _, _)) => - a.toArray[Any](at.elementType).toSeq - case (other, _) => - other - }.mkString("[", ",", "]") + val encodedData = try { + convertedData.toSeq(encoder.schema).zip(encoder.schema).map { + case (a: ArrayData, StructField(_, at: ArrayType, _, _)) => + a.toArray[Any](at.elementType).toSeq + case (other, _) => + other + }.mkString("[", ",", "]") + } catch { + case e: Throwable => s"Failed to 
toSeq: $e" + } fail( s"""Encoded/Decoded data does not match input data @@ -272,8 +276,9 @@ class ProductEncoderSuite extends SparkFunSuite { |Construct Expressions: |${boundEncoder.constructExpression.treeString} | - """.stripMargin) + """.stripMargin) + } } - } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index 37d559c8e4301fcb93daa251fe54ba5f279c64e6..de11a1699afd959691382b5ba570883e2a3395ee 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -17,11 +17,13 @@ package org.apache.spark.sql + import scala.language.implicitConversions import org.apache.spark.annotation.Experimental import org.apache.spark.Logging import org.apache.spark.sql.functions.lit +import org.apache.spark.sql.catalyst.encoders.Encoder import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.types._ @@ -36,6 +38,11 @@ private[sql] object Column { def unapply(col: Column): Option[Expression] = Some(col.expr) } +/** + * A [[Column]] where an [[Encoder]] has been given for the expected return type. + * @since 1.6.0 + */ +class TypedColumn[T](expr: Expression)(implicit val encoder: Encoder[T]) extends Column(expr) /** * :: Experimental :: @@ -69,6 +76,14 @@ class Column(protected[sql] val expr: Expression) extends Logging { override def hashCode: Int = this.expr.hashCode + /** + * Provides a type hint about the expected return value of this column. This information can + * be used by operations such as `select` on a [[Dataset]] to automatically convert the + * results into the correct JVM types. + * @since 1.6.0 + */ + def as[T : Encoder]: TypedColumn[T] = new TypedColumn[T](expr) + /** * Extracts a value or values from a complex type. * The following types of extraction are supported: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 2f10aa9f3c4469b8702170774194575c3ba06f9d..bf25bcde208e269909b4f20e2ba234fe0d441e30 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -34,6 +34,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.encoders.Encoder import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.{Inner, JoinType} import org.apache.spark.sql.catalyst.{CatalystTypeConverters, ScalaReflection, SqlParser} @@ -258,6 +259,16 @@ class DataFrame private[sql]( // `rdd.toDF("1")` as invoking this toDF and then apply on the returned DataFrame. def toDF(): DataFrame = this + /** + * :: Experimental :: + * Converts this [[DataFrame]] to a strongly-typed [[Dataset]] containing objects of the + * specified type, `U`. + * @group basic + * @since 1.6.0 + */ + @Experimental + def as[U : Encoder]: Dataset[U] = new Dataset[U](sqlContext, queryExecution) + /** * Returns a new [[DataFrame]] with columns renamed. This can be quite convenient in conversion * from a RDD of tuples into a [[DataFrame]] with meaningful names. 
For example: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala new file mode 100644 index 0000000000000000000000000000000000000000..96213c7630400c2e93f77a98a1b4c13517e7be1d --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -0,0 +1,392 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.annotation.Experimental +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.encoders._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.execution.QueryExecution +import org.apache.spark.sql.types.StructType + +/** + * A [[Dataset]] is a strongly typed collection of objects that can be transformed in parallel + * using functional or relational operations. + * + * A [[Dataset]] differs from an [[RDD]] in the following ways: + * - Internally, a [[Dataset]] is represented by a Catalyst logical plan and the data is stored + * in the encoded form. This representation allows for additional logical operations and + * enables many operations (sorting, shuffling, etc.) to be performed without deserializing to + * an object. + * - The creation of a [[Dataset]] requires the presence of an explicit [[Encoder]] that can be + * used to serialize the object into a binary format. Encoders are also capable of mapping the + * schema of a given object to the Spark SQL type system. In contrast, RDDs rely on runtime + * reflection based serialization. Operations that change the type of object stored in the + * dataset also need an encoder for the new type. + * + * A [[Dataset]] can be thought of as a specialized DataFrame, where the elements map to a specific + * JVM object type, instead of to a generic [[Row]] container. A DataFrame can be transformed into + * specific Dataset by calling `df.as[ElementType]`. Similarly you can transform a strongly-typed + * [[Dataset]] to a generic DataFrame by calling `ds.toDF()`. + * + * COMPATIBILITY NOTE: Long term we plan to make [[DataFrame]] extend `Dataset[Row]`. However, + * making this change to the class hierarchy would break the function signatures for the existing + * functional operations (map, flatMap, etc). As such, this class should be considered a preview + * of the final API. Changes will be made to the interface after Spark 1.6. 
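+ *
+ * A short, illustrative sketch of the intended usage (`Person` is a hypothetical case class and
+ * the required implicit encoders are assumed to be in scope):
+ * {{{
+ *   val people: Dataset[Person] = df.as[Person]               // typed view over a DataFrame
+ *   val adults: Dataset[String] = people.filter(_.age >= 18).map(_.name)
+ * }}}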
+ * + * @since 1.6.0 + */ +@Experimental +class Dataset[T] private[sql]( + @transient val sqlContext: SQLContext, + @transient val queryExecution: QueryExecution)( + implicit val encoder: Encoder[T]) extends Serializable { + + private implicit def classTag = encoder.clsTag + + private[sql] def this(sqlContext: SQLContext, plan: LogicalPlan)(implicit encoder: Encoder[T]) = + this(sqlContext, new QueryExecution(sqlContext, plan)) + + /** Returns the schema of the encoded form of the objects in this [[Dataset]]. */ + def schema: StructType = encoder.schema + + /* ************* * + * Conversions * + * ************* */ + + /** + * Returns a new `Dataset` where each record has been mapped on to the specified type. + * TODO: should bind here... + * TODO: document binding rules + * @since 1.6.0 + */ + def as[U : Encoder]: Dataset[U] = new Dataset(sqlContext, queryExecution)(implicitly[Encoder[U]]) + + /** + * Applies a logical alias to this [[Dataset]] that can be used to disambiguate columns that have + * the same name after two Datasets have been joined. + */ + def as(alias: String): Dataset[T] = withPlan(Subquery(alias, _)) + + /** + * Converts this strongly typed collection of data to generic Dataframe. In contrast to the + * strongly typed objects that Dataset operations work on, a Dataframe returns generic [[Row]] + * objects that allow fields to be accessed by ordinal or name. + */ + def toDF(): DataFrame = DataFrame(sqlContext, logicalPlan) + + + /** + * Returns this Dataset. + * @since 1.6.0 + */ + def toDS(): Dataset[T] = this + + /** + * Converts this Dataset to an RDD. + * @since 1.6.0 + */ + def rdd: RDD[T] = { + val tEnc = implicitly[Encoder[T]] + val input = queryExecution.analyzed.output + queryExecution.toRdd.mapPartitions { iter => + val bound = tEnc.bind(input) + iter.map(bound.fromRow) + } + } + + /* *********************** * + * Functional Operations * + * *********************** */ + + /** + * Concise syntax for chaining custom transformations. + * {{{ + * def featurize(ds: Dataset[T]) = ... + * + * dataset + * .transform(featurize) + * .transform(...) + * }}} + * + * @since 1.6.0 + */ + def transform[U](t: Dataset[T] => Dataset[U]): Dataset[U] = t(this) + + /** + * Returns a new [[Dataset]] that only contains elements where `func` returns `true`. + * @since 1.6.0 + */ + def filter(func: T => Boolean): Dataset[T] = mapPartitions(_.filter(func)) + + /** + * Returns a new [[Dataset]] that contains the result of applying `func` to each element. + * @since 1.6.0 + */ + def map[U : Encoder](func: T => U): Dataset[U] = mapPartitions(_.map(func)) + + /** + * Returns a new [[Dataset]] that contains the result of applying `func` to each element. + * @since 1.6.0 + */ + def mapPartitions[U : Encoder](func: Iterator[T] => Iterator[U]): Dataset[U] = { + new Dataset( + sqlContext, + MapPartitions[T, U]( + func, + implicitly[Encoder[T]], + implicitly[Encoder[U]], + implicitly[Encoder[U]].schema.toAttributes, + logicalPlan)) + } + + def flatMap[U : Encoder](func: T => TraversableOnce[U]): Dataset[U] = + mapPartitions(_.flatMap(func)) + + /* ************** * + * Side effects * + * ************** */ + + /** + * Runs `func` on each element of this Dataset. + * @since 1.6.0 + */ + def foreach(func: T => Unit): Unit = rdd.foreach(func) + + /** + * Runs `func` on each partition of this Dataset. 
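+ * For example (illustrative), to print every element one partition at a time:
+ * {{{
+ *   ds.foreachPartition(iter => iter.foreach(println))
+ * }}}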
+ * @since 1.6.0 + */ + def foreachPartition(func: Iterator[T] => Unit): Unit = rdd.foreachPartition(func) + + /* ************* * + * Aggregation * + * ************* */ + + /** + * Reduces the elements of this Dataset using the specified binary function. The given function + * must be commutative and associative or the result may be non-deterministic. + * @since 1.6.0 + */ + def reduce(func: (T, T) => T): T = rdd.reduce(func) + + /** + * Aggregates the elements of each partition, and then the results for all the partitions, using a + * given associative and commutative function and a neutral "zero value". + * + * This behaves somewhat differently than the fold operations implemented for non-distributed + * collections in functional languages like Scala. This fold operation may be applied to + * partitions individually, and then those results will be folded into the final result. + * If op is not commutative, then the result may differ from that of a fold applied to a + * non-distributed collection. + * @since 1.6.0 + */ + def fold(zeroValue: T)(op: (T, T) => T): T = rdd.fold(zeroValue)(op) + + /** + * Returns a [[GroupedDataset]] where the data is grouped by the given key function. + * @since 1.6.0 + */ + def groupBy[K : Encoder](func: T => K): GroupedDataset[K, T] = { + val inputPlan = queryExecution.analyzed + val withGroupingKey = AppendColumn(func, inputPlan) + val executed = sqlContext.executePlan(withGroupingKey) + + new GroupedDataset( + implicitly[Encoder[K]].bindOrdinals(withGroupingKey.newColumns), + implicitly[Encoder[T]].bind(inputPlan.output), + executed, + inputPlan.output, + withGroupingKey.newColumns) + } + + /* ****************** * + * Typed Relational * + * ****************** */ + + /** + * Returns a new [[Dataset]] by computing the given [[Column]] expression for each element. + * + * {{{ + * val ds = Seq(1, 2, 3).toDS() + * val newDS = ds.select(e[Int]("value + 1")) + * }}} + * @since 1.6.0 + */ + def select[U1: Encoder](c1: TypedColumn[U1]): Dataset[U1] = { + new Dataset[U1](sqlContext, Project(Alias(c1.expr, "_1")() :: Nil, logicalPlan)) + } + + // Codegen + // scalastyle:off + + /** sbt scalaShell; println(Seq(1).toDS().genSelect) */ + private def genSelect: String = { + (2 to 5).map { n => + val types = (1 to n).map(i =>s"U$i").mkString(", ") + val args = (1 to n).map(i => s"c$i: TypedColumn[U$i]").mkString(", ") + val encoders = (1 to n).map(i => s"c$i.encoder").mkString(", ") + val schema = (1 to n).map(i => s"""Alias(c$i.expr, "_$i")()""").mkString(" :: ") + s""" + |/** + | * Returns a new [[Dataset]] by computing the given [[Column]] expressions for each element. + | * @since 1.6.0 + | */ + |def select[$types]($args): Dataset[($types)] = { + | implicit val te = new Tuple${n}Encoder($encoders) + | new Dataset[($types)](sqlContext, + | Project( + | $schema :: Nil, + | logicalPlan)) + |} + | + """.stripMargin + }.mkString("\n") + } + + /** + * Returns a new [[Dataset]] by computing the given [[Column]] expressions for each element. + * @since 1.6.0 + */ + def select[U1, U2](c1: TypedColumn[U1], c2: TypedColumn[U2]): Dataset[(U1, U2)] = { + implicit val te = new Tuple2Encoder(c1.encoder, c2.encoder) + new Dataset[(U1, U2)](sqlContext, + Project( + Alias(c1.expr, "_1")() :: Alias(c2.expr, "_2")() :: Nil, + logicalPlan)) + } + + + + /** + * Returns a new [[Dataset]] by computing the given [[Column]] expressions for each element. 
+ * @since 1.6.0 + */ + def select[U1, U2, U3](c1: TypedColumn[U1], c2: TypedColumn[U2], c3: TypedColumn[U3]): Dataset[(U1, U2, U3)] = { + implicit val te = new Tuple3Encoder(c1.encoder, c2.encoder, c3.encoder) + new Dataset[(U1, U2, U3)](sqlContext, + Project( + Alias(c1.expr, "_1")() :: Alias(c2.expr, "_2")() :: Alias(c3.expr, "_3")() :: Nil, + logicalPlan)) + } + + + + /** + * Returns a new [[Dataset]] by computing the given [[Column]] expressions for each element. + * @since 1.6.0 + */ + def select[U1, U2, U3, U4](c1: TypedColumn[U1], c2: TypedColumn[U2], c3: TypedColumn[U3], c4: TypedColumn[U4]): Dataset[(U1, U2, U3, U4)] = { + implicit val te = new Tuple4Encoder(c1.encoder, c2.encoder, c3.encoder, c4.encoder) + new Dataset[(U1, U2, U3, U4)](sqlContext, + Project( + Alias(c1.expr, "_1")() :: Alias(c2.expr, "_2")() :: Alias(c3.expr, "_3")() :: Alias(c4.expr, "_4")() :: Nil, + logicalPlan)) + } + + + + /** + * Returns a new [[Dataset]] by computing the given [[Column]] expressions for each element. + * @since 1.6.0 + */ + def select[U1, U2, U3, U4, U5](c1: TypedColumn[U1], c2: TypedColumn[U2], c3: TypedColumn[U3], c4: TypedColumn[U4], c5: TypedColumn[U5]): Dataset[(U1, U2, U3, U4, U5)] = { + implicit val te = new Tuple5Encoder(c1.encoder, c2.encoder, c3.encoder, c4.encoder, c5.encoder) + new Dataset[(U1, U2, U3, U4, U5)](sqlContext, + Project( + Alias(c1.expr, "_1")() :: Alias(c2.expr, "_2")() :: Alias(c3.expr, "_3")() :: Alias(c4.expr, "_4")() :: Alias(c5.expr, "_5")() :: Nil, + logicalPlan)) + } + + // scalastyle:on + + /* **************** * + * Set operations * + * **************** */ + + /** + * Returns a new [[Dataset]] that contains only the unique elements of this [[Dataset]]. + * + * Note that, equality checking is performed directly on the encoded representation of the data + * and thus is not affected by a custom `equals` function defined on `T`. + * @since 1.6.0 + */ + def distinct: Dataset[T] = withPlan(Distinct) + + /** + * Returns a new [[Dataset]] that contains only the elements of this [[Dataset]] that are also + * present in `other`. + * + * Note that, equality checking is performed directly on the encoded representation of the data + * and thus is not affected by a custom `equals` function defined on `T`. + * @since 1.6.0 + */ + def intersect(other: Dataset[T]): Dataset[T] = + withPlan[T](other)(Intersect) + + /** + * Returns a new [[Dataset]] that contains the elements of both this and the `other` [[Dataset]] + * combined. + * + * Note that, this function is not a typical set union operation, in that it does not eliminate + * duplicate items. As such, it is analagous to `UNION ALL` in SQL. + * @since 1.6.0 + */ + def union(other: Dataset[T]): Dataset[T] = + withPlan[T](other)(Union) + + /** + * Returns a new [[Dataset]] where any elements present in `other` have been removed. + * + * Note that, equality checking is performed directly on the encoded representation of the data + * and thus is not affected by a custom `equals` function defined on `T`. + * @since 1.6.0 + */ + def subtract(other: Dataset[T]): Dataset[T] = withPlan[T](other)(Except) + + /* ************************** * + * Gather to Driver Actions * + * ************************** */ + + /** Returns the first element in this [[Dataset]]. */ + def first(): T = rdd.first() + + /** Collects the elements to an Array. */ + def collect(): Array[T] = rdd.collect() + + /** Returns the first `num` elements of this [[Dataset]] as an Array. 
*/ + def take(num: Int): Array[T] = rdd.take(num) + + /* ******************** * + * Internal Functions * + * ******************** */ + + private[sql] def logicalPlan = queryExecution.analyzed + + private[sql] def withPlan(f: LogicalPlan => LogicalPlan): Dataset[T] = + new Dataset[T](sqlContext, sqlContext.executePlan(f(logicalPlan))) + + private[sql] def withPlan[R : Encoder]( + other: Dataset[_])( + f: (LogicalPlan, LogicalPlan) => LogicalPlan): Dataset[R] = + new Dataset[R]( + sqlContext, + sqlContext.executePlan( + f(logicalPlan, other.logicalPlan))) +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DatasetHolder.scala b/sql/core/src/main/scala/org/apache/spark/sql/DatasetHolder.scala new file mode 100644 index 0000000000000000000000000000000000000000..17817cbcc5e050dc9081c57cf8aa66e1163739ee --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/DatasetHolder.scala @@ -0,0 +1,30 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql + +/** + * A container for a [[DataFrame]], used for implicit conversions. + * + * @since 1.3.0 + */ +private[sql] case class DatasetHolder[T](df: Dataset[T]) { + + // This is declared with parentheses to prevent the Scala compiler from treating + // `rdd.toDF("1")` as invoking this toDF and then apply on the returned DataFrame. + def toDS(): Dataset[T] = df +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala new file mode 100644 index 0000000000000000000000000000000000000000..89a16dd8b0accb46f10e02aaca47c3c4673e4094 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql + +import org.apache.spark.sql.catalyst.encoders.Encoder +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.execution.QueryExecution + +/** + * A [[Dataset]] has been logically grouped by a user specified grouping key. Users should not + * construct a [[GroupedDataset]] directly, but should instead call `groupBy` on an existing + * [[Dataset]]. + */ +class GroupedDataset[K, T] private[sql]( + private val kEncoder: Encoder[K], + private val tEncoder: Encoder[T], + queryExecution: QueryExecution, + private val dataAttributes: Seq[Attribute], + private val groupingAttributes: Seq[Attribute]) extends Serializable { + + private implicit def kEnc = kEncoder + private implicit def tEnc = tEncoder + private def logicalPlan = queryExecution.analyzed + private def sqlContext = queryExecution.sqlContext + + /** + * Returns a [[Dataset]] that contains each unique key. + */ + def keys: Dataset[K] = { + new Dataset[K]( + sqlContext, + Distinct( + Project(groupingAttributes, logicalPlan))) + } + + /** + * Applies the given function to each group of data. For each unique group, the function will + * be passed the group key and an iterator that contains all of the elements in the group. The + * function can return an iterator containing elements of an arbitrary type which will be returned + * as a new [[Dataset]]. + * + * Internally, the implementation will spill to disk if any given group is too large to fit into + * memory. However, users must take care to avoid materializing the whole iterator for a group + * (for example, by calling `toList`) unless they are sure that this is possible given the memory + * constraints of their cluster. + */ + def mapGroups[U : Encoder](f: (K, Iterator[T]) => Iterator[U]): Dataset[U] = { + new Dataset[U]( + sqlContext, + MapGroups(f, groupingAttributes, logicalPlan)) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index a107639947aa2bc4fdc1b88e32596727de7c11c6..5e7198f974389d3ec2c0554892a6d82eff099472 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -21,6 +21,7 @@ import java.beans.{BeanInfo, Introspector} import java.util.Properties import java.util.concurrent.atomic.AtomicReference + import scala.collection.JavaConverters._ import scala.collection.immutable import scala.reflect.runtime.universe.TypeTag @@ -33,6 +34,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationEnd} import org.apache.spark.sql.SQLConf.SQLConfEntry import org.apache.spark.sql.catalyst.analysis._ +import org.apache.spark.sql.catalyst.encoders.Encoder import org.apache.spark.sql.catalyst.errors.DialectException import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.optimizer.{DefaultOptimizer, Optimizer} @@ -487,6 +489,16 @@ class SQLContext private[sql]( DataFrame(this, logicalPlan) } + + def createDataset[T : Encoder](data: Seq[T]): Dataset[T] = { + val enc = implicitly[Encoder[T]] + val attributes = enc.schema.toAttributes + val encoded = data.map(d => enc.toRow(d).copy()) + val plan = new LocalRelation(attributes, encoded) + + new Dataset[T](this, plan) + } + /** * Creates a DataFrame from an RDD[Row]. User can specify whether the input rows should be * converted to Catalyst rows. 
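
The following sketch illustrates how the pieces added in this patch are intended to compose:
`createDataset` encodes local objects with an implicit Encoder, `as[T]` re-interprets a DataFrame
as a typed Dataset, and the functional operators round-trip through the encoders. It is
illustrative only: the `Person` case class is hypothetical, and `sqlContext` is assumed to be an
existing SQLContext whose implicits are imported.

    // Assumes: val sqlContext: SQLContext = ... and types from org.apache.spark.sql._
    import sqlContext.implicits._

    case class Person(name: String, age: Int)

    // Local objects are encoded with the implicitly derived ProductEncoder[Person].
    val ds: Dataset[Person] = sqlContext.createDataset(Seq(Person("a", 30), Person("b", 7)))

    // A DataFrame can be viewed as a typed Dataset and vice versa.
    val typed: Dataset[Person] = ds.toDF().as[Person]

    // Functional operators decode each row to an object, apply the closure, and re-encode.
    val names: Dataset[String] = ds.map(_.name)

    // groupBy keys the data by the closure's result; mapGroups sees each key with its group.
    val summary: Dataset[String] =
      ds.groupBy(_.age).mapGroups { (age, people) =>
        Iterator(s"$age: ${people.size} person(s)")
      }
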
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala index bf03c610884263ca0f6dda7f17b87dc54dc06a48..af8474df0de808b1ea6ffd23adf53cb70c7b3432 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala @@ -17,6 +17,10 @@ package org.apache.spark.sql +import org.apache.spark.sql.catalyst.encoders._ +import org.apache.spark.sql.catalyst.plans.logical.LocalRelation +import org.apache.spark.sql.execution.datasources.LogicalRelation + import scala.language.implicitConversions import scala.reflect.runtime.universe.TypeTag @@ -30,9 +34,19 @@ import org.apache.spark.unsafe.types.UTF8String /** * A collection of implicit methods for converting common Scala objects into [[DataFrame]]s. */ -private[sql] abstract class SQLImplicits { +abstract class SQLImplicits { protected def _sqlContext: SQLContext + implicit def newProductEncoder[T <: Product : TypeTag]: Encoder[T] = ProductEncoder[T] + + implicit def newIntEncoder: Encoder[Int] = new IntEncoder() + implicit def newLongEncoder: Encoder[Long] = new LongEncoder() + implicit def newStringEncoder: Encoder[String] = new StringEncoder() + + implicit def localSeqToDatasetHolder[T : Encoder](s: Seq[T]): DatasetHolder[T] = { + DatasetHolder(_sqlContext.createDataset(s)) + } + /** * An implicit conversion that turns a Scala `Symbol` into a [[Column]]. * @since 1.3.0 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GroupedIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GroupedIterator.scala new file mode 100644 index 0000000000000000000000000000000000000000..10742cf7348f803a4bca1631a6ba8463096339d2 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GroupedIterator.scala @@ -0,0 +1,141 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateUnsafeProjection, GenerateOrdering} +import org.apache.spark.sql.catalyst.expressions.{Attribute, SortOrder, Ascending, Expression} + +object GroupedIterator { + def apply( + input: Iterator[InternalRow], + keyExpressions: Seq[Expression], + inputSchema: Seq[Attribute]): Iterator[(InternalRow, Iterator[InternalRow])] = { + if (input.hasNext) { + new GroupedIterator(input, keyExpressions, inputSchema) + } else { + Iterator.empty + } + } +} + +/** + * Iterates over a presorted set of rows, chunking it up by the grouping expression. Each call to + * next will return a pair containing the current group and an iterator that will return all the + * elements of that group. 
Iterators for each group are lazily constructed by extracting rows + * from the input iterator. As such, full groups are never materialized by this class. + * + * Example input: + * {{{ + * Input: [a, 1], [b, 2], [b, 3] + * Grouping: x#1 + * InputSchema: x#1, y#2 + * }}} + * + * Result: + * {{{ + * First call to next(): ([a], Iterator([a, 1])) + * Second call to next(): ([b], Iterator([b, 2], [b, 3])) + * }}} + * + * Note that the class does not handle the case of an empty input, for simplicity of implementation. + * Use the factory to construct a new instance. + * + * @param input An iterator of rows. This iterator must be ordered by the groupingExpressions, or + * it is possible for the same group to appear more than once. + * @param groupingExpressions The set of expressions used to do grouping. The result of evaluating + * these expressions will be returned as the first part of each call + * to `next()`. + * @param inputSchema The schema of the rows in the `input` iterator. + */ +class GroupedIterator private( + input: Iterator[InternalRow], + groupingExpressions: Seq[Expression], + inputSchema: Seq[Attribute]) + extends Iterator[(InternalRow, Iterator[InternalRow])] { + + /** Compares two input rows and returns 0 if they are in the same group. */ + val sortOrder = groupingExpressions.map(SortOrder(_, Ascending)) + val keyOrdering = GenerateOrdering.generate(sortOrder, inputSchema) + + /** Creates a row containing only the key for a given input row. */ + val keyProjection = GenerateUnsafeProjection.generate(groupingExpressions, inputSchema) + + /** + * Holds null or the row that will be returned on the next call to `next()` in the inner iterator. + */ + var currentRow = input.next() + + /** Holds a copy of an input row that is in the current group. */ + var currentGroup = currentRow.copy() + var currentIterator: Iterator[InternalRow] = null + assert(keyOrdering.compare(currentGroup, currentRow) == 0) + + // Return true if we already have the next iterator or fetching a new iterator is successful. + def hasNext: Boolean = currentIterator != null || fetchNextGroupIterator + + def next(): (InternalRow, Iterator[InternalRow]) = { + assert(hasNext) // Ensure we have fetched the next iterator. + val ret = (keyProjection(currentGroup), currentIterator) + currentIterator = null + ret + } + + def fetchNextGroupIterator(): Boolean = { + if (currentRow != null || input.hasNext) { + val inputIterator = new Iterator[InternalRow] { + // Return true if we have a row and it is in the current group, or if fetching a new row is + // successful. + def hasNext = { + (currentRow != null && keyOrdering.compare(currentGroup, currentRow) == 0) || + fetchNextRowInGroup() + } + + def fetchNextRowInGroup(): Boolean = { + if (currentRow != null || input.hasNext) { + currentRow = input.next() + if (keyOrdering.compare(currentGroup, currentRow) == 0) { + // The row is in the current group. Continue the inner iterator. + true + } else { + // We got a row, but it's not in the right group. End this inner iterator and prepare + // for the next group. + currentIterator = null + currentGroup = currentRow.copy() + false + } + } else { + // There is no more input, so we are done. + false + } + } + + def next(): InternalRow = { + assert(hasNext) // Ensure we have fetched the next row.
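+ // Return the buffered row and clear it so that the next `hasNext` call fetches a fresh row from the input.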
+ val res = currentRow + currentRow = null + res + } + } + currentIterator = inputIterator + true + } else { + false + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 79bd1a41808dec066b22f05de7669621d38b644b..637deff4e22022380944223a3e0f24ef9f3949fc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -372,6 +372,14 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case logical.Distinct(child) => throw new IllegalStateException( "logical distinct operator should have been replaced by aggregate in the optimizer") + + case logical.MapPartitions(f, tEnc, uEnc, output, child) => + execution.MapPartitions(f, tEnc, uEnc, output, planLater(child)) :: Nil + case logical.AppendColumn(f, tEnc, uEnc, newCol, child) => + execution.AppendColumns(f, tEnc, uEnc, newCol, planLater(child)) :: Nil + case logical.MapGroups(f, kEnc, tEnc, uEnc, grouping, output, child) => + execution.MapGroups(f, kEnc, tEnc, uEnc, grouping, output, planLater(child)) :: Nil + case logical.Repartition(numPartitions, shuffle, child) => if (shuffle) { execution.Exchange(RoundRobinPartitioning(numPartitions), planLater(child)) :: Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index dc38fe59feed5ff4f7dab53e2eeb437ee6e9e0c6..2bb3dba5bd2baf6f7c58d132a714590c67f32fe4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -20,7 +20,9 @@ package org.apache.spark.sql.execution import org.apache.spark.rdd.{PartitionwiseSampledRDD, RDD, ShuffledRDD} import org.apache.spark.shuffle.sort.SortShuffleManager import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.encoders.Encoder import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeRowJoiner import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.util.MutablePair @@ -311,3 +313,80 @@ case class OutputFaker(output: Seq[Attribute], child: SparkPlan) extends SparkPl protected override def doExecute(): RDD[InternalRow] = child.execute() } + +/** + * Applies the given function to the iterator of deserialized objects in each input partition and + * encodes the results. + */ +case class MapPartitions[T, U]( + func: Iterator[T] => Iterator[U], + tEncoder: Encoder[T], + uEncoder: Encoder[U], + output: Seq[Attribute], + child: SparkPlan) extends UnaryNode { + + override protected def doExecute(): RDD[InternalRow] = { + child.execute().mapPartitions { iter => + val tBoundEncoder = tEncoder.bind(child.output) + func(iter.map(tBoundEncoder.fromRow)).map(uEncoder.toRow) + } + } +} + +/** + * Applies the given function to each input row, appending the encoded result at the end of the row.
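+ * + * For example, if an input row is [a, 1] and the function's result for that object encodes to [2], the joined output row is [a, 1, 2].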
+ */ +case class AppendColumns[T, U]( + func: T => U, + tEncoder: Encoder[T], + uEncoder: Encoder[U], + newColumns: Seq[Attribute], + child: SparkPlan) extends UnaryNode { + + override def output: Seq[Attribute] = child.output ++ newColumns + + override protected def doExecute(): RDD[InternalRow] = { + child.execute().mapPartitions { iter => + val tBoundEncoder = tEncoder.bind(child.output) + val combiner = GenerateUnsafeRowJoiner.create(tEncoder.schema, uEncoder.schema) + iter.map { row => + val newColumns = uEncoder.toRow(func(tBoundEncoder.fromRow(row))) + combiner.join(row.asInstanceOf[UnsafeRow], newColumns.asInstanceOf[UnsafeRow]): InternalRow + } + } + } +} + +/** + * Groups the input rows together and calls the function with each group and an iterator containing + * all elements in the group. The result of this function is encoded and flattened before + * being output. + */ +case class MapGroups[K, T, U]( + func: (K, Iterator[T]) => Iterator[U], + kEncoder: Encoder[K], + tEncoder: Encoder[T], + uEncoder: Encoder[U], + groupingAttributes: Seq[Attribute], + output: Seq[Attribute], + child: SparkPlan) extends UnaryNode { + + override def requiredChildDistribution: Seq[Distribution] = + ClusteredDistribution(groupingAttributes) :: Nil + + override def requiredChildOrdering: Seq[Seq[SortOrder]] = + Seq(groupingAttributes.map(SortOrder(_, Ascending))) + + override protected def doExecute(): RDD[InternalRow] = { + child.execute().mapPartitions { iter => + val grouped = GroupedIterator(iter, groupingAttributes, child.output) + val groupKeyEncoder = kEncoder.bind(groupingAttributes) + + grouped.flatMap { case (key, rowIter) => + val result = func( + groupKeyEncoder.fromRow(key), + rowIter.map(tEncoder.fromRow)) + result.map(uEncoder.toRow) + } + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala new file mode 100644 index 0000000000000000000000000000000000000000..32443557fb8e0a1c64117abe3e23d9f5ccee8d62 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql + +import scala.language.postfixOps + +import org.apache.spark.sql.test.SharedSQLContext + +case class IntClass(value: Int) + +class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext { + import testImplicits._ + + test("toDS") { + val data = Seq(1, 2, 3, 4, 5, 6) + checkAnswer( + data.toDS(), + data: _*) + } + + test("as case class / collect") { + val ds = Seq(1, 2, 3).toDS().as[IntClass] + checkAnswer( + ds, + IntClass(1), IntClass(2), IntClass(3)) + + assert(ds.collect().head == IntClass(1)) + } + + test("map") { + val ds = Seq(1, 2, 3).toDS() + checkAnswer( + ds.map(_ + 1), + 2, 3, 4) + } + + test("filter") { + val ds = Seq(1, 2, 3, 4).toDS() + checkAnswer( + ds.filter(_ % 2 == 0), + 2, 4) + } + + test("foreach") { + val ds = Seq(1, 2, 3).toDS() + val acc = sparkContext.accumulator(0) + ds.foreach(acc +=) + assert(acc.value == 6) + } + + test("foreachPartition") { + val ds = Seq(1, 2, 3).toDS() + val acc = sparkContext.accumulator(0) + ds.foreachPartition(_.foreach(acc +=)) + assert(acc.value == 6) + } + + test("reduce") { + val ds = Seq(1, 2, 3).toDS() + assert(ds.reduce(_ + _) == 6) + } + + test("fold") { + val ds = Seq(1, 2, 3).toDS() + assert(ds.fold(0)(_ + _) == 6) + } + + test("groupBy function, keys") { + val ds = Seq(1, 2, 3, 4, 5).toDS() + val grouped = ds.groupBy(_ % 2) + checkAnswer( + grouped.keys, + 0, 1) + } + + test("groupBy function, mapGroups") { + val ds = Seq(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11).toDS() + val grouped = ds.groupBy(_ % 2) + val agged = grouped.mapGroups { case (g, iter) => + val name = if (g == 0) "even" else "odd" + Iterator((name, iter.size)) + } + + checkAnswer( + agged, + ("even", 5), ("odd", 6)) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala new file mode 100644 index 0000000000000000000000000000000000000000..08496249c60cc0e55c07de1926e74caf796c8a50 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -0,0 +1,124 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql + +import scala.language.postfixOps + +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.SharedSQLContext + +case class ClassData(a: String, b: Int) + +class DatasetSuite extends QueryTest with SharedSQLContext { + import testImplicits._ + + test("toDS") { + val data = Seq(("a", 1) , ("b", 2), ("c", 3)) + checkAnswer( + data.toDS(), + data: _*) + } + + test("as case class / collect") { + val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDF("a", "b").as[ClassData] + checkAnswer( + ds, + ClassData("a", 1), ClassData("b", 2), ClassData("c", 3)) + assert(ds.collect().head == ClassData("a", 1)) + } + + test("as case class - reordered fields by name") { + val ds = Seq((1, "a"), (2, "b"), (3, "c")).toDF("b", "a").as[ClassData] + assert(ds.collect() === Array(ClassData("a", 1), ClassData("b", 2), ClassData("c", 3))) + } + + test("map") { + val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDS() + checkAnswer( + ds.map(v => (v._1, v._2 + 1)), + ("a", 2), ("b", 3), ("c", 4)) + } + + test("select") { + val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDS() + checkAnswer( + ds.select(expr("_2 + 1").as[Int]), + 2, 3, 4) + } + + test("select 3") { + val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDS() + checkAnswer( + ds.select( + expr("_1").as[String], + expr("_2").as[Int], + expr("_2 + 1").as[Int]), + ("a", 1, 2), ("b", 2, 3), ("c", 3, 4)) + } + + test("filter") { + val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDS() + checkAnswer( + ds.filter(_._1 == "b"), + ("b", 2)) + } + + test("foreach") { + val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDS() + val acc = sparkContext.accumulator(0) + ds.foreach(v => acc += v._2) + assert(acc.value == 6) + } + + test("foreachPartition") { + val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDS() + val acc = sparkContext.accumulator(0) + ds.foreachPartition(_.foreach(v => acc += v._2)) + assert(acc.value == 6) + } + + test("reduce") { + val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDS() + assert(ds.reduce((a, b) => ("sum", a._2 + b._2)) == ("sum", 6)) + } + + test("fold") { + val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDS() + assert(ds.fold(("", 0))((a, b) => ("sum", a._2 + b._2)) == ("sum", 6)) + } + + test("groupBy function, keys") { + val ds = Seq(("a", 1), ("b", 1)).toDS() + val grouped = ds.groupBy(v => (1, v._2)) + checkAnswer( + grouped.keys, + (1, 1)) + } + + test("groupBy function, mapGroups") { + val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() + val grouped = ds.groupBy(v => (v._1, "word")) + val agged = grouped.mapGroups { case (g, iter) => + Iterator((g._1, iter.map(_._2).sum)) + } + + checkAnswer( + agged, + ("a", 30), ("b", 3), ("c", 1)) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index e3c5a426671d01a13df5ad6deedba8fb4f015859..aba567512fe328d45dd1f721e20846c8eaf91a94 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -20,10 +20,12 @@ package org.apache.spark.sql import java.util.{Locale, TimeZone} import scala.collection.JavaConverters._ +import scala.reflect.runtime.universe._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.columnar.InMemoryRelation +import org.apache.spark.sql.catalyst.encoders.{ProductEncoder, Encoder} abstract class QueryTest extends PlanTest { @@ -53,6 +55,12 @@ abstract class QueryTest extends PlanTest { } 
} + protected def checkAnswer[T : Encoder](ds: => Dataset[T], expectedAnswer: T*): Unit = { + checkAnswer( + ds.toDF(), + sqlContext.createDataset(expectedAnswer).toDF().collect().toSeq) + } + /** * Runs the plan and makes sure the answer matches the expected result. * @param df the [[DataFrame]] to be executed