diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index c25161ee81b6614762c34ad2fc59b7fa39ba7a76..9cbb7c2ffdc764334fc2cb7855aa7cd2b25b57d9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -146,6 +146,10 @@ trait ScalaReflection { * row with a compatible schema. Fields of the row will be extracted using UnresolvedAttributes * of the same name as the constructor arguments. Nested classes will have their fields accessed * using UnresolvedExtractValue. + * + * When used on a primitive type, the constructor will instead default to extracting the value + * from ordinal 0 (since there are no names to map to). The actual location can be moved by + * calling unbind/bind with a new schema. */ def constructorFor[T : TypeTag]: Expression = constructorFor(typeOf[T], None) @@ -159,8 +163,14 @@ trait ScalaReflection { .map(p => UnresolvedExtractValue(p, expressions.Literal(part))) .getOrElse(UnresolvedAttribute(part)) + /** Returns the current path with a field at ordinal extracted. */ + def addToPathOrdinal(ordinal: Int, dataType: DataType) = + path + .map(p => GetStructField(p, StructField(s"_$ordinal", dataType), ordinal)) + .getOrElse(BoundReference(ordinal, dataType, false)) + /** Returns the current path or throws an error. */ - def getPath = path.getOrElse(sys.error("Constructors must start at a class type")) + def getPath = path.getOrElse(BoundReference(0, dataTypeFor(tpe), true)) tpe match { case t if !dataTypeFor(t).isInstanceOf[ObjectType] => @@ -387,12 +397,17 @@ trait ScalaReflection { val className: String = t.erasure.typeSymbol.asClass.fullName val cls = Utils.classForName(className) - val arguments = params.head.map { p => + val arguments = params.head.zipWithIndex.map { case (p, i) => val fieldName = p.name.toString val fieldType = p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs) - val dataType = dataTypeFor(fieldType) + val dataType = schemaFor(fieldType).dataType - constructorFor(fieldType, Some(addToPath(fieldName))) + // For tuples, we based grab the inner fields by ordinal instead of name. + if (className startsWith "scala.Tuple") { + constructorFor(fieldType, Some(addToPathOrdinal(i, dataType))) + } else { + constructorFor(fieldType, Some(addToPath(fieldName))) + } } val newInstance = NewInstance(cls, arguments, propagateNull = false, ObjectType(cls)) @@ -413,7 +428,10 @@ trait ScalaReflection { /** Returns expressions for extracting all the fields from the given type. 
*/ def extractorsFor[T : TypeTag](inputObject: Expression): CreateNamedStruct = { ScalaReflectionLock.synchronized { - extractorFor(inputObject, typeTag[T].tpe).asInstanceOf[CreateNamedStruct] + extractorFor(inputObject, typeTag[T].tpe) match { + case s: CreateNamedStruct => s + case o => CreateNamedStruct(expressions.Literal("value") :: o :: Nil) + } } } @@ -602,6 +620,21 @@ trait ScalaReflection { case t if t <:< localTypeOf[java.lang.Boolean] => Invoke(inputObject, "booleanValue", BooleanType) + case t if t <:< definitions.IntTpe => + BoundReference(0, IntegerType, false) + case t if t <:< definitions.LongTpe => + BoundReference(0, LongType, false) + case t if t <:< definitions.DoubleTpe => + BoundReference(0, DoubleType, false) + case t if t <:< definitions.FloatTpe => + BoundReference(0, FloatType, false) + case t if t <:< definitions.ShortTpe => + BoundReference(0, ShortType, false) + case t if t <:< definitions.ByteTpe => + BoundReference(0, ByteType, false) + case t if t <:< definitions.BooleanTpe => + BoundReference(0, BooleanType, false) + case other => throw new UnsupportedOperationException(s"Extractor for type $other is not supported") } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ClassEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ClassEncoder.scala deleted file mode 100644 index b484b8fde63692b7c73cfedb4a69144d8b4dea9d..0000000000000000000000000000000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ClassEncoder.scala +++ /dev/null @@ -1,101 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.encoders - -import scala.reflect.ClassTag - -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, SimpleAnalyzer} -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection, GenerateUnsafeProjection} -import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project} -import org.apache.spark.sql.types.{ObjectType, StructType} - -/** - * A generic encoder for JVM objects. - * - * @param schema The schema after converting `T` to a Spark SQL row. - * @param extractExpressions A set of expressions, one for each top-level field that can be used to - * extract the values from a raw object. - * @param clsTag A classtag for `T`. 
- */ -case class ClassEncoder[T]( - schema: StructType, - extractExpressions: Seq[Expression], - constructExpression: Expression, - clsTag: ClassTag[T]) - extends Encoder[T] { - - @transient - private lazy val extractProjection = GenerateUnsafeProjection.generate(extractExpressions) - private val inputRow = new GenericMutableRow(1) - - @transient - private lazy val constructProjection = GenerateSafeProjection.generate(constructExpression :: Nil) - private val dataType = ObjectType(clsTag.runtimeClass) - - override def toRow(t: T): InternalRow = { - inputRow(0) = t - extractProjection(inputRow) - } - - override def fromRow(row: InternalRow): T = { - constructProjection(row).get(0, dataType).asInstanceOf[T] - } - - override def bind(schema: Seq[Attribute]): ClassEncoder[T] = { - val plan = Project(Alias(constructExpression, "object")() :: Nil, LocalRelation(schema)) - val analyzedPlan = SimpleAnalyzer.execute(plan) - val resolvedExpression = analyzedPlan.expressions.head.children.head - val boundExpression = BindReferences.bindReference(resolvedExpression, schema) - - copy(constructExpression = boundExpression) - } - - override def rebind(oldSchema: Seq[Attribute], newSchema: Seq[Attribute]): ClassEncoder[T] = { - val positionToAttribute = AttributeMap.toIndex(oldSchema) - val attributeToNewPosition = AttributeMap.byIndex(newSchema) - copy(constructExpression = constructExpression transform { - case r: BoundReference => - r.copy(ordinal = attributeToNewPosition(positionToAttribute(r.ordinal))) - }) - } - - override def bindOrdinals(schema: Seq[Attribute]): ClassEncoder[T] = { - var remaining = schema - copy(constructExpression = constructExpression transform { - case u: UnresolvedAttribute => - val pos = remaining.head - remaining = remaining.drop(1) - pos - }) - } - - protected val attrs = extractExpressions.map(_.collect { - case a: Attribute => s"#${a.exprId}" - case b: BoundReference => s"[${b.ordinal}]" - }.headOption.getOrElse("")) - - - protected val schemaString = - schema - .zip(attrs) - .map { case(f, a) => s"${f.name}$a: ${f.dataType.simpleString}"}.mkString(", ") - - override def toString: String = s"class[$schemaString]" -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala index efb872ddb81e52ba1ca35222aefb5af36979739c..329a132d3d8b241d9c2878b874c43769238adbc9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/Encoder.scala @@ -18,10 +18,9 @@ package org.apache.spark.sql.catalyst.encoders + import scala.reflect.ClassTag -import org.apache.spark.sql.catalyst.expressions.Attribute -import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.types.StructType /** @@ -30,44 +29,11 @@ import org.apache.spark.sql.types.StructType * Encoders are not intended to be thread-safe and thus they are allow to avoid internal locking * and reuse internal buffers to improve performance. */ -trait Encoder[T] { +trait Encoder[T] extends Serializable { /** Returns the schema of encoding this type of object as a Row. */ def schema: StructType /** A ClassTag that can be used to construct and Array to contain a collection of `T`. */ def clsTag: ClassTag[T] - - /** - * Returns an encoded version of `t` as a Spark SQL row. Note that multiple calls to - * toRow are allowed to return the same actual [[InternalRow]] object. 
Thus, the caller should - * copy the result before making another call if required. - */ - def toRow(t: T): InternalRow - - /** - * Returns an object of type `T`, extracting the required values from the provided row. Note that - * you must `bind` an encoder to a specific schema before you can call this function. - */ - def fromRow(row: InternalRow): T - - /** - * Returns a new copy of this encoder, where the expressions used by `fromRow` are bound to the - * given schema. - */ - def bind(schema: Seq[Attribute]): Encoder[T] - - /** - * Binds this encoder to the given schema positionally. In this binding, the first reference to - * any input is mapped to `schema(0)`, and so on for each input that is encountered. - */ - def bindOrdinals(schema: Seq[Attribute]): Encoder[T] - - /** - * Given an encoder that has already been bound to a given schema, returns a new encoder that - * where the positions are mapped from `oldSchema` to `newSchema`. This can be used, for example, - * when you are trying to use an encoder on grouping keys that were orriginally part of a larger - * row, but now you have projected out only the key expressions. - */ - def rebind(oldSchema: Seq[Attribute], newSchema: Seq[Attribute]): Encoder[T] } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala new file mode 100644 index 0000000000000000000000000000000000000000..c287aebeeee059ae107ddd5c6af7b2251d00f06f --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala @@ -0,0 +1,217 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.encoders + +import org.apache.spark.sql.catalyst.analysis.{SimpleAnalyzer, UnresolvedExtractValue, UnresolvedAttribute} +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project} +import org.apache.spark.util.Utils + +import scala.reflect.ClassTag +import scala.reflect.runtime.universe.{typeTag, TypeTag} + +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection, GenerateUnsafeProjection} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.ScalaReflection +import org.apache.spark.sql.types.{StructField, DataType, ObjectType, StructType} + +/** + * A factory for constructing encoders that convert objects and primitves to and from the + * internal row format using catalyst expressions and code generation. 
By default, the + * expressions used to retrieve values from an input row when producing an object will be created as + * follows: + * - Classes will have their sub fields extracted by name using [[UnresolvedAttribute]] expressions + * and [[UnresolvedExtractValue]] expressions. + * - Tuples will have their subfields extracted by position using [[BoundReference]] expressions. + * - Primitives will have their values extracted from the first ordinal with a schema that defaults + * to the name `value`. + */ +object ExpressionEncoder { + def apply[T : TypeTag](flat: Boolean = false): ExpressionEncoder[T] = { + // We convert the not-serializable TypeTag into StructType and ClassTag. + val mirror = typeTag[T].mirror + val cls = mirror.runtimeClass(typeTag[T].tpe) + + val inputObject = BoundReference(0, ObjectType(cls), nullable = true) + val extractExpression = ScalaReflection.extractorsFor[T](inputObject) + val constructExpression = ScalaReflection.constructorFor[T] + + new ExpressionEncoder[T]( + extractExpression.dataType, + flat, + extractExpression.flatten, + constructExpression, + ClassTag[T](cls)) + } + + /** + * Given a set of N encoders, constructs a new encoder that produce objects as items in an + * N-tuple. Note that these encoders should first be bound correctly to the combined input + * schema. + */ + def tuple(encoders: Seq[ExpressionEncoder[_]]): ExpressionEncoder[_] = { + val schema = + StructType( + encoders.zipWithIndex.map { case (e, i) => StructField(s"_${i + 1}", e.schema)}) + val cls = Utils.getContextOrSparkClassLoader.loadClass(s"scala.Tuple${encoders.size}") + val extractExpressions = encoders.map { + case e if e.flat => e.extractExpressions.head + case other => CreateStruct(other.extractExpressions) + } + val constructExpression = + NewInstance(cls, encoders.map(_.constructExpression), false, ObjectType(cls)) + + new ExpressionEncoder[Any]( + schema, + false, + extractExpressions, + constructExpression, + ClassTag.apply(cls)) + } + + /** A helper for producing encoders of Tuple2 from other encoders. */ + def tuple[T1, T2]( + e1: ExpressionEncoder[T1], + e2: ExpressionEncoder[T2]): ExpressionEncoder[(T1, T2)] = + tuple(e1 :: e2 :: Nil).asInstanceOf[ExpressionEncoder[(T1, T2)]] +} + +/** + * A generic encoder for JVM objects. + * + * @param schema The schema after converting `T` to a Spark SQL row. + * @param extractExpressions A set of expressions, one for each top-level field that can be used to + * extract the values from a raw object. + * @param clsTag A classtag for `T`. + */ +case class ExpressionEncoder[T]( + schema: StructType, + flat: Boolean, + extractExpressions: Seq[Expression], + constructExpression: Expression, + clsTag: ClassTag[T]) + extends Encoder[T] { + + if (flat) require(extractExpressions.size == 1) + + @transient + private lazy val extractProjection = GenerateUnsafeProjection.generate(extractExpressions) + private val inputRow = new GenericMutableRow(1) + + @transient + private lazy val constructProjection = GenerateSafeProjection.generate(constructExpression :: Nil) + + /** + * Returns an encoded version of `t` as a Spark SQL row. Note that multiple calls to + * toRow are allowed to return the same actual [[InternalRow]] object. Thus, the caller should + * copy the result before making another call if required. + */ + def toRow(t: T): InternalRow = { + inputRow(0) = t + extractProjection(inputRow) + } + + /** + * Returns an object of type `T`, extracting the required values from the provided row. 
Note that + * you must `resolve` and `bind` an encoder to a specific schema before you can call this + * function. + */ + def fromRow(row: InternalRow): T = try { + constructProjection(row).get(0, ObjectType(clsTag.runtimeClass)).asInstanceOf[T] + } catch { + case e: Exception => + throw new RuntimeException(s"Error while decoding: $e\n${constructExpression.treeString}", e) + } + + /** + * Returns a new copy of this encoder, where the expressions used by `fromRow` are resolved to the + * given schema. + */ + def resolve(schema: Seq[Attribute]): ExpressionEncoder[T] = { + val plan = Project(Alias(constructExpression, "")() :: Nil, LocalRelation(schema)) + val analyzedPlan = SimpleAnalyzer.execute(plan) + copy(constructExpression = analyzedPlan.expressions.head.children.head) + } + + /** + * Returns a copy of this encoder where the expressions used to construct an object from an input + * row have been bound to the ordinals of the given schema. Note that you need to first call + * resolve before bind. + */ + def bind(schema: Seq[Attribute]): ExpressionEncoder[T] = { + copy(constructExpression = BindReferences.bindReference(constructExpression, schema)) + } + + /** + * Replaces any bound references in the schema with the attributes at the corresponding ordinal + * in the provided schema. This can be used to "relocate" a given encoder to pull values from + * a different schema than it was initially bound to. It can also be used to assign attributes + * to ordinal based extraction (i.e. because the input data was a tuple). + */ + def unbind(schema: Seq[Attribute]): ExpressionEncoder[T] = { + val positionToAttribute = AttributeMap.toIndex(schema) + copy(constructExpression = constructExpression transform { + case b: BoundReference => positionToAttribute(b.ordinal) + }) + } + + /** + * Given an encoder that has already been bound to a given schema, returns a new encoder + * where the positions are mapped from `oldSchema` to `newSchema`. This can be used, for example, + * when you are trying to use an encoder on grouping keys that were originally part of a larger + * row, but now you have projected out only the key expressions. + */ + def rebind(oldSchema: Seq[Attribute], newSchema: Seq[Attribute]): ExpressionEncoder[T] = { + val positionToAttribute = AttributeMap.toIndex(oldSchema) + val attributeToNewPosition = AttributeMap.byIndex(newSchema) + copy(constructExpression = constructExpression transform { + case r: BoundReference => + r.copy(ordinal = attributeToNewPosition(positionToAttribute(r.ordinal))) + }) + } + + /** + * Returns a copy of this encoder where the expressions used to create an object given an + * input row have been modified to pull the object out from a nested struct, instead of the + * top level fields. 
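
To make the lifecycle of the new encoder concrete, a minimal sketch (not patch code) of the intended usage: `toRow` works on a freshly constructed encoder, while `fromRow` first needs `resolve` and then `bind` against a concrete schema, exactly as the updated test suite does. The `Person` case class is a hypothetical example.

    import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder

    case class Person(name: String, age: Int)   // hypothetical, defined at top level

    val enc = ExpressionEncoder[Person]()
    val row = enc.toRow(Person("Ann", 30))       // InternalRow in the unsafe format

    // fromRow requires the construct expression to be resolved and bound first,
    // here against the encoder's own schema (as ExpressionEncoderSuite does).
    val attrs = enc.schema.toAttributes
    val back  = enc.resolve(attrs).bind(attrs).fromRow(row)   // Person("Ann", 30)

`unbind`/`rebind` then allow the same encoder to be pointed at a different input schema without re-deriving the expressions.
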
+ */ + def nested(input: Expression = BoundReference(0, schema, true)): ExpressionEncoder[T] = { + copy(constructExpression = constructExpression transform { + case u: Attribute if u != input => + UnresolvedExtractValue(input, Literal(u.name)) + case b: BoundReference if b != input => + GetStructField( + input, + StructField(s"i[${b.ordinal}]", b.dataType), + b.ordinal) + }) + } + + protected val attrs = extractExpressions.flatMap(_.collect { + case _: UnresolvedAttribute => "" + case a: Attribute => s"#${a.exprId}" + case b: BoundReference => s"[${b.ordinal}]" + }) + + protected val schemaString = + schema + .zip(attrs) + .map { case(f, a) => s"${f.name}$a: ${f.dataType.simpleString}"}.mkString(", ") + + override def toString: String = s"class[$schemaString]" +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala deleted file mode 100644 index 34f5e6c030f5889ecf32ebb433b5b0156aeca208..0000000000000000000000000000000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala +++ /dev/null @@ -1,47 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.encoders - -import scala.reflect.ClassTag -import scala.reflect.runtime.universe.{typeTag, TypeTag} - -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.ScalaReflection -import org.apache.spark.sql.types.{ObjectType, StructType} - -/** - * A factory for constructing encoders that convert Scala's product type to/from the Spark SQL - * internal binary representation. - */ -object ProductEncoder { - def apply[T <: Product : TypeTag]: ClassEncoder[T] = { - // We convert the not-serializable TypeTag into StructType and ClassTag. 
- val mirror = typeTag[T].mirror - val cls = mirror.runtimeClass(typeTag[T].tpe) - - val inputObject = BoundReference(0, ObjectType(cls), nullable = true) - val extractExpression = ScalaReflection.extractorsFor[T](inputObject) - val constructExpression = ScalaReflection.constructorFor[T] - - new ClassEncoder[T]( - extractExpression.dataType, - extractExpression.flatten, - constructExpression, - ClassTag[T](cls)) - } -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala index e9cc00a2b64ce054899432b999f6f40f13916c77..0b42130a013b2575959d6169cafdeb3c06cd1523 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala @@ -31,13 +31,14 @@ import org.apache.spark.unsafe.types.UTF8String * internal binary representation. */ object RowEncoder { - def apply(schema: StructType): ClassEncoder[Row] = { + def apply(schema: StructType): ExpressionEncoder[Row] = { val cls = classOf[Row] val inputObject = BoundReference(0, ObjectType(cls), nullable = true) val extractExpressions = extractorsFor(inputObject, schema) val constructExpression = constructorFor(schema) - new ClassEncoder[Row]( + new ExpressionEncoder[Row]( schema, + flat = false, extractExpressions.asInstanceOf[CreateStruct].children, constructExpression, ClassTag(cls)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/PrimitiveEncoderSuite.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/package.scala similarity index 56% rename from sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/PrimitiveEncoderSuite.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/package.scala index 52f8383faca92fc43eeb0d3e1bafadbcc85a0702..d4642a500672ef82629f05a94d563be005a44ed0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/PrimitiveEncoderSuite.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/package.scala @@ -15,29 +15,12 @@ * limitations under the License. 
*/ -package org.apache.spark.sql.catalyst.encoders +package org.apache.spark.sql.catalyst -import org.apache.spark.SparkFunSuite - -class PrimitiveEncoderSuite extends SparkFunSuite { - test("long encoder") { - val enc = new LongEncoder() - val row = enc.toRow(10) - assert(row.getLong(0) == 10) - assert(enc.fromRow(row) == 10) - } - - test("int encoder") { - val enc = new IntEncoder() - val row = enc.toRow(10) - assert(row.getInt(0) == 10) - assert(enc.fromRow(row) == 10) - } - - test("string encoder") { - val enc = new StringEncoder() - val row = enc.toRow("test") - assert(row.getString(0) == "test") - assert(enc.fromRow(row) == "test") +package object encoders { + private[sql] def encoderFor[A : Encoder]: ExpressionEncoder[A] = implicitly[Encoder[A]] match { + case e: ExpressionEncoder[A] => e + case _ => sys.error(s"Only expression encoders are supported today") } } + diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/primitiveTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/primitiveTypes.scala deleted file mode 100644 index a93f2d7c6115dad6363617123275827163b3bfbe..0000000000000000000000000000000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/primitiveTypes.scala +++ /dev/null @@ -1,100 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.encoders - -import scala.reflect.ClassTag - -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.unsafe.types.UTF8String -import org.apache.spark.sql.types._ - -/** An encoder for primitive Long types. */ -case class LongEncoder(fieldName: String = "value", ordinal: Int = 0) extends Encoder[Long] { - private val row = UnsafeRow.createFromByteArray(64, 1) - - override def clsTag: ClassTag[Long] = ClassTag.Long - override def schema: StructType = - StructType(StructField(fieldName, LongType) :: Nil) - - override def fromRow(row: InternalRow): Long = row.getLong(ordinal) - - override def toRow(t: Long): InternalRow = { - row.setLong(ordinal, t) - row - } - - override def bindOrdinals(schema: Seq[Attribute]): Encoder[Long] = this - override def bind(schema: Seq[Attribute]): Encoder[Long] = this - override def rebind(oldSchema: Seq[Attribute], newSchema: Seq[Attribute]): Encoder[Long] = this -} - -/** An encoder for primitive Integer types. 
*/ -case class IntEncoder(fieldName: String = "value", ordinal: Int = 0) extends Encoder[Int] { - private val row = UnsafeRow.createFromByteArray(64, 1) - - override def clsTag: ClassTag[Int] = ClassTag.Int - override def schema: StructType = - StructType(StructField(fieldName, IntegerType) :: Nil) - - override def fromRow(row: InternalRow): Int = row.getInt(ordinal) - - override def toRow(t: Int): InternalRow = { - row.setInt(ordinal, t) - row - } - - override def bindOrdinals(schema: Seq[Attribute]): Encoder[Int] = this - override def bind(schema: Seq[Attribute]): Encoder[Int] = this - override def rebind(oldSchema: Seq[Attribute], newSchema: Seq[Attribute]): Encoder[Int] = this -} - -/** An encoder for String types. */ -case class StringEncoder( - fieldName: String = "value", - ordinal: Int = 0) extends Encoder[String] { - - val record = new SpecificMutableRow(StringType :: Nil) - - @transient - lazy val projection = - GenerateUnsafeProjection.generate(BoundReference(0, StringType, true) :: Nil) - - override def schema: StructType = - StructType( - StructField("value", StringType, nullable = false) :: Nil) - - override def clsTag: ClassTag[String] = scala.reflect.classTag[String] - - - override final def fromRow(row: InternalRow): String = { - row.getString(ordinal) - } - - override final def toRow(value: String): InternalRow = { - val utf8String = UTF8String.fromString(value) - record(0) = utf8String - // TODO: this is a bit of a hack to produce UnsafeRows - projection(record) - } - - override def bindOrdinals(schema: Seq[Attribute]): Encoder[String] = this - override def bind(schema: Seq[Attribute]): Encoder[String] = this - override def rebind(oldSchema: Seq[Attribute], newSchema: Seq[Attribute]): Encoder[String] = this -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/tuples.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/tuples.scala deleted file mode 100644 index a48eeda7d2e6f034e6a5e006d6ec200bff31e5ff..0000000000000000000000000000000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/tuples.scala +++ /dev/null @@ -1,173 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.encoders - - -import scala.reflect.ClassTag - -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.Attribute -import org.apache.spark.sql.types.{StructField, StructType} - -// Most of this file is codegen. -// scalastyle:off - -/** - * A set of composite encoders that take sub encoders and map each of their objects to a - * Scala tuple. Note that currently the implementation is fairly limited and only supports going - * from an internal row to a tuple. 
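
For reference, a sketch of how the hand-generated TupleNEncoder classes above are now expressed by composing expression encoders (using only what the new ExpressionEncoder companion object provides); unlike the deleted classes, the composed encoder supports both `toRow` and `fromRow` once resolved and bound.

    import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder

    val intEnc    = ExpressionEncoder[Int](flat = true)
    val stringEnc = ExpressionEncoder[String](flat = true)

    // schema: struct<_1: int, _2: string>
    val pairEnc: ExpressionEncoder[(Int, String)] =
      ExpressionEncoder.tuple(intEnc, stringEnc)
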
- */ -object TupleEncoder { - - /** Code generator for composite tuple encoders. */ - def main(args: Array[String]): Unit = { - (2 to 5).foreach { i => - val types = (1 to i).map(t => s"T$t").mkString(", ") - val tupleType = s"($types)" - val args = (1 to i).map(t => s"e$t: Encoder[T$t]").mkString(", ") - val fields = (1 to i).map(t => s"""StructField("_$t", e$t.schema)""").mkString(", ") - val fromRow = (1 to i).map(t => s"e$t.fromRow(row)").mkString(", ") - - println( - s""" - |class Tuple${i}Encoder[$types]($args) extends Encoder[$tupleType] { - | val schema = StructType(Array($fields)) - | - | def clsTag: ClassTag[$tupleType] = scala.reflect.classTag[$tupleType] - | - | def fromRow(row: InternalRow): $tupleType = { - | ($fromRow) - | } - | - | override def toRow(t: $tupleType): InternalRow = - | throw new UnsupportedOperationException("Tuple Encoders only support fromRow.") - | - | override def bind(schema: Seq[Attribute]): Encoder[$tupleType] = { - | this - | } - | - | override def rebind(oldSchema: Seq[Attribute], newSchema: Seq[Attribute]): Encoder[$tupleType] = - | throw new UnsupportedOperationException("Tuple Encoders only support bind.") - | - | - | override def bindOrdinals(schema: Seq[Attribute]): Encoder[$tupleType] = - | throw new UnsupportedOperationException("Tuple Encoders only support bind.") - |} - """.stripMargin) - } - } -} - -class Tuple2Encoder[T1, T2](e1: Encoder[T1], e2: Encoder[T2]) extends Encoder[(T1, T2)] { - val schema = StructType(Array(StructField("_1", e1.schema), StructField("_2", e2.schema))) - - def clsTag: ClassTag[(T1, T2)] = scala.reflect.classTag[(T1, T2)] - - def fromRow(row: InternalRow): (T1, T2) = { - (e1.fromRow(row), e2.fromRow(row)) - } - - override def toRow(t: (T1, T2)): InternalRow = - throw new UnsupportedOperationException("Tuple Encoders only support fromRow.") - - override def bind(schema: Seq[Attribute]): Encoder[(T1, T2)] = { - this - } - - override def rebind(oldSchema: Seq[Attribute], newSchema: Seq[Attribute]): Encoder[(T1, T2)] = - throw new UnsupportedOperationException("Tuple Encoders only support bind.") - - - override def bindOrdinals(schema: Seq[Attribute]): Encoder[(T1, T2)] = - throw new UnsupportedOperationException("Tuple Encoders only support bind.") -} - - -class Tuple3Encoder[T1, T2, T3](e1: Encoder[T1], e2: Encoder[T2], e3: Encoder[T3]) extends Encoder[(T1, T2, T3)] { - val schema = StructType(Array(StructField("_1", e1.schema), StructField("_2", e2.schema), StructField("_3", e3.schema))) - - def clsTag: ClassTag[(T1, T2, T3)] = scala.reflect.classTag[(T1, T2, T3)] - - def fromRow(row: InternalRow): (T1, T2, T3) = { - (e1.fromRow(row), e2.fromRow(row), e3.fromRow(row)) - } - - override def toRow(t: (T1, T2, T3)): InternalRow = - throw new UnsupportedOperationException("Tuple Encoders only support fromRow.") - - override def bind(schema: Seq[Attribute]): Encoder[(T1, T2, T3)] = { - this - } - - override def rebind(oldSchema: Seq[Attribute], newSchema: Seq[Attribute]): Encoder[(T1, T2, T3)] = - throw new UnsupportedOperationException("Tuple Encoders only support bind.") - - - override def bindOrdinals(schema: Seq[Attribute]): Encoder[(T1, T2, T3)] = - throw new UnsupportedOperationException("Tuple Encoders only support bind.") -} - - -class Tuple4Encoder[T1, T2, T3, T4](e1: Encoder[T1], e2: Encoder[T2], e3: Encoder[T3], e4: Encoder[T4]) extends Encoder[(T1, T2, T3, T4)] { - val schema = StructType(Array(StructField("_1", e1.schema), StructField("_2", e2.schema), StructField("_3", e3.schema), StructField("_4", 
e4.schema))) - - def clsTag: ClassTag[(T1, T2, T3, T4)] = scala.reflect.classTag[(T1, T2, T3, T4)] - - def fromRow(row: InternalRow): (T1, T2, T3, T4) = { - (e1.fromRow(row), e2.fromRow(row), e3.fromRow(row), e4.fromRow(row)) - } - - override def toRow(t: (T1, T2, T3, T4)): InternalRow = - throw new UnsupportedOperationException("Tuple Encoders only support fromRow.") - - override def bind(schema: Seq[Attribute]): Encoder[(T1, T2, T3, T4)] = { - this - } - - override def rebind(oldSchema: Seq[Attribute], newSchema: Seq[Attribute]): Encoder[(T1, T2, T3, T4)] = - throw new UnsupportedOperationException("Tuple Encoders only support bind.") - - - override def bindOrdinals(schema: Seq[Attribute]): Encoder[(T1, T2, T3, T4)] = - throw new UnsupportedOperationException("Tuple Encoders only support bind.") -} - - -class Tuple5Encoder[T1, T2, T3, T4, T5](e1: Encoder[T1], e2: Encoder[T2], e3: Encoder[T3], e4: Encoder[T4], e5: Encoder[T5]) extends Encoder[(T1, T2, T3, T4, T5)] { - val schema = StructType(Array(StructField("_1", e1.schema), StructField("_2", e2.schema), StructField("_3", e3.schema), StructField("_4", e4.schema), StructField("_5", e5.schema))) - - def clsTag: ClassTag[(T1, T2, T3, T4, T5)] = scala.reflect.classTag[(T1, T2, T3, T4, T5)] - - def fromRow(row: InternalRow): (T1, T2, T3, T4, T5) = { - (e1.fromRow(row), e2.fromRow(row), e3.fromRow(row), e4.fromRow(row), e5.fromRow(row)) - } - - override def toRow(t: (T1, T2, T3, T4, T5)): InternalRow = - throw new UnsupportedOperationException("Tuple Encoders only support fromRow.") - - override def bind(schema: Seq[Attribute]): Encoder[(T1, T2, T3, T4, T5)] = { - this - } - - override def rebind(oldSchema: Seq[Attribute], newSchema: Seq[Attribute]): Encoder[(T1, T2, T3, T4, T5)] = - throw new UnsupportedOperationException("Tuple Encoders only support bind.") - - - override def bindOrdinals(schema: Seq[Attribute]): Encoder[(T1, T2, T3, T4, T5)] = - throw new UnsupportedOperationException("Tuple Encoders only support bind.") -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 21a55a53718410d5d70d6fbbf90b87034628a794..d2d3db0a4448474000b750474c0ea534dacb0dfd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.plans.logical -import org.apache.spark.sql.catalyst.encoders.Encoder +import org.apache.spark.sql.catalyst.encoders._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.Utils import org.apache.spark.sql.catalyst.plans._ @@ -450,8 +450,8 @@ case object OneRowRelation extends LeafNode { */ case class MapPartitions[T, U]( func: Iterator[T] => Iterator[U], - tEncoder: Encoder[T], - uEncoder: Encoder[U], + tEncoder: ExpressionEncoder[T], + uEncoder: ExpressionEncoder[U], output: Seq[Attribute], child: LogicalPlan) extends UnaryNode { override def missingInput: AttributeSet = AttributeSet.empty @@ -460,8 +460,8 @@ case class MapPartitions[T, U]( /** Factory for constructing new `AppendColumn` nodes. 
*/ object AppendColumn { def apply[T : Encoder, U : Encoder](func: T => U, child: LogicalPlan): AppendColumn[T, U] = { - val attrs = implicitly[Encoder[U]].schema.toAttributes - new AppendColumn[T, U](func, implicitly[Encoder[T]], implicitly[Encoder[U]], attrs, child) + val attrs = encoderFor[U].schema.toAttributes + new AppendColumn[T, U](func, encoderFor[T], encoderFor[U], attrs, child) } } @@ -472,8 +472,8 @@ object AppendColumn { */ case class AppendColumn[T, U]( func: T => U, - tEncoder: Encoder[T], - uEncoder: Encoder[U], + tEncoder: ExpressionEncoder[T], + uEncoder: ExpressionEncoder[U], newColumns: Seq[Attribute], child: LogicalPlan) extends UnaryNode { override def output: Seq[Attribute] = child.output ++ newColumns @@ -488,11 +488,11 @@ object MapGroups { child: LogicalPlan): MapGroups[K, T, U] = { new MapGroups( func, - implicitly[Encoder[K]], - implicitly[Encoder[T]], - implicitly[Encoder[U]], + encoderFor[K], + encoderFor[T], + encoderFor[U], groupingAttributes, - implicitly[Encoder[U]].schema.toAttributes, + encoderFor[U].schema.toAttributes, child) } } @@ -504,9 +504,9 @@ object MapGroups { */ case class MapGroups[K, T, U]( func: (K, Iterator[T]) => Iterator[U], - kEncoder: Encoder[K], - tEncoder: Encoder[T], - uEncoder: Encoder[U], + kEncoder: ExpressionEncoder[K], + tEncoder: ExpressionEncoder[T], + uEncoder: ExpressionEncoder[U], groupingAttributes: Seq[Attribute], output: Seq[Attribute], child: LogicalPlan) extends UnaryNode { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala similarity index 91% rename from sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala rename to sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala index 008d0bea8a941499438c0b47ce53e22a85203283..a374da4da1f081f75b61d761acc6ea74e5115bdb 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala @@ -47,7 +47,16 @@ case class RepeatedData( case class SpecificCollection(l: List[Int]) -class ProductEncoderSuite extends SparkFunSuite { +class ExpressionEncoderSuite extends SparkFunSuite { + + encodeDecodeTest(1) + encodeDecodeTest(1L) + encodeDecodeTest(1.toDouble) + encodeDecodeTest(1.toFloat) + encodeDecodeTest(true) + encodeDecodeTest(false) + encodeDecodeTest(1.toShort) + encodeDecodeTest(1.toByte) encodeDecodeTest(PrimitiveData(1, 1, 1, 1, 1, 1, true)) @@ -210,24 +219,24 @@ class ProductEncoderSuite extends SparkFunSuite { { (l, r) => l._2.toString == r._2.toString } /** Simplified encodeDecodeTestCustom, where the comparison function can be `Object.equals`. */ - protected def encodeDecodeTest[T <: Product : TypeTag](inputData: T) = + protected def encodeDecodeTest[T : TypeTag](inputData: T) = encodeDecodeTestCustom[T](inputData)((l, r) => l == r) /** * Constructs a test that round-trips `t` through an encoder, checking the results to ensure it * matches the original. 
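
Each of the new primitive cases (`encodeDecodeTest(1)`, `encodeDecodeTest(1L)`, ...) boils down to the same round trip through a flat encoder whose schema is a single `value` column; roughly, and for illustration only:

    val longEnc = ExpressionEncoder[Long](flat = true)   // schema: value: bigint
    val row     = longEnc.toRow(10L)
    val attrs   = longEnc.schema.toAttributes
    assert(longEnc.resolve(attrs).bind(attrs).fromRow(row) == 10L)
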
*/ - protected def encodeDecodeTestCustom[T <: Product : TypeTag]( + protected def encodeDecodeTestCustom[T : TypeTag]( inputData: T)( c: (T, T) => Boolean) = { - test(s"encode/decode: $inputData") { - val encoder = try ProductEncoder[T] catch { + test(s"encode/decode: $inputData - ${inputData.getClass.getName}") { + val encoder = try ExpressionEncoder[T]() catch { case e: Exception => fail(s"Exception thrown generating encoder", e) } val convertedData = encoder.toRow(inputData) val schema = encoder.schema.toAttributes - val boundEncoder = encoder.bind(schema) + val boundEncoder = encoder.resolve(schema).bind(schema) val convertedBack = try boundEncoder.fromRow(convertedData) catch { case e: Exception => fail( @@ -236,15 +245,19 @@ class ProductEncoderSuite extends SparkFunSuite { |Schema: ${schema.mkString(",")} |${encoder.schema.treeString} | - |Construct Expressions: - |${boundEncoder.constructExpression.treeString} + |Encoder: + |$boundEncoder | """.stripMargin, e) } if (!c(inputData, convertedBack)) { - val types = - convertedBack.productIterator.filter(_ != null).map(_.getClass.getName).mkString(",") + val types = convertedBack match { + case c: Product => + c.productIterator.filter(_ != null).map(_.getClass.getName).mkString(",") + case other => other.getClass.getName + } + val encodedData = try { convertedData.toSeq(encoder.schema).zip(encoder.schema).map { @@ -269,11 +282,7 @@ class ProductEncoderSuite extends SparkFunSuite { |${encoder.schema.treeString} | |Extract Expressions: - |${boundEncoder.extractExpressions.map(_.treeString).mkString("\n")} - | - |Construct Expressions: - |${boundEncoder.constructExpression.treeString} - | + |$boundEncoder """.stripMargin) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 32d9b0b1d9888f062018fa11a471719c0d1cb306..aa817a037ef5e93e6088b86391095799b9e0b129 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -267,7 +267,7 @@ class DataFrame private[sql]( * @since 1.6.0 */ @Experimental - def as[U : Encoder]: Dataset[U] = new Dataset[U](sqlContext, queryExecution) + def as[U : Encoder]: Dataset[U] = new Dataset[U](sqlContext, logicalPlan) /** * Returns a new [[DataFrame]] with columns renamed. 
This can be quite convenient in conversion diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 96213c7630400c2e93f77a98a1b4c13517e7be1d..e0ab5f593e933212327adc75b5e1cbfdf52d0f08 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -21,6 +21,7 @@ import org.apache.spark.annotation.Experimental import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.encoders._ import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.Inner import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.QueryExecution import org.apache.spark.sql.types.StructType @@ -53,15 +54,21 @@ import org.apache.spark.sql.types.StructType * @since 1.6.0 */ @Experimental -class Dataset[T] private[sql]( +class Dataset[T] private( @transient val sqlContext: SQLContext, - @transient val queryExecution: QueryExecution)( - implicit val encoder: Encoder[T]) extends Serializable { + @transient val queryExecution: QueryExecution, + unresolvedEncoder: Encoder[T]) extends Serializable { + + /** The encoder for this [[Dataset]] that has been resolved to its output schema. */ + private[sql] implicit val encoder: ExpressionEncoder[T] = unresolvedEncoder match { + case e: ExpressionEncoder[T] => e.resolve(queryExecution.analyzed.output) + case _ => throw new IllegalArgumentException("Only expression encoders are currently supported") + } private implicit def classTag = encoder.clsTag private[sql] def this(sqlContext: SQLContext, plan: LogicalPlan)(implicit encoder: Encoder[T]) = - this(sqlContext, new QueryExecution(sqlContext, plan)) + this(sqlContext, new QueryExecution(sqlContext, plan), encoder) /** Returns the schema of the encoded form of the objects in this [[Dataset]]. */ def schema: StructType = encoder.schema @@ -76,7 +83,9 @@ class Dataset[T] private[sql]( * TODO: document binding rules * @since 1.6.0 */ - def as[U : Encoder]: Dataset[U] = new Dataset(sqlContext, queryExecution)(implicitly[Encoder[U]]) + def as[U : Encoder]: Dataset[U] = { + new Dataset(sqlContext, queryExecution, encoderFor[U]) + } /** * Applies a logical alias to this [[Dataset]] that can be used to disambiguate columns that have @@ -103,7 +112,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ def rdd: RDD[T] = { - val tEnc = implicitly[Encoder[T]] + val tEnc = encoderFor[T] val input = queryExecution.analyzed.output queryExecution.toRdd.mapPartitions { iter => val bound = tEnc.bind(input) @@ -150,9 +159,9 @@ class Dataset[T] private[sql]( sqlContext, MapPartitions[T, U]( func, - implicitly[Encoder[T]], - implicitly[Encoder[U]], - implicitly[Encoder[U]].schema.toAttributes, + encoderFor[T], + encoderFor[U], + encoderFor[U].schema.toAttributes, logicalPlan)) } @@ -209,8 +218,8 @@ class Dataset[T] private[sql]( val executed = sqlContext.executePlan(withGroupingKey) new GroupedDataset( - implicitly[Encoder[K]].bindOrdinals(withGroupingKey.newColumns), - implicitly[Encoder[T]].bind(inputPlan.output), + encoderFor[K].resolve(withGroupingKey.newColumns), + encoderFor[T].bind(inputPlan.output), executed, inputPlan.output, withGroupingKey.newColumns) @@ -220,6 +229,18 @@ class Dataset[T] private[sql]( * Typed Relational * * ****************** */ + /** + * Selects a set of column based expressions. 
+ * {{{ + * df.select($"colA", $"colB" + 1) + * }}} + * @group dfops + * @since 1.3.0 + */ + // Copied from Dataframe to make sure we don't have invalid overloads. + @scala.annotation.varargs + def select(cols: Column*): DataFrame = toDF().select(cols: _*) + /** * Returns a new [[Dataset]] by computing the given [[Column]] expression for each element. * @@ -233,88 +254,64 @@ class Dataset[T] private[sql]( new Dataset[U1](sqlContext, Project(Alias(c1.expr, "_1")() :: Nil, logicalPlan)) } - // Codegen - // scalastyle:off - - /** sbt scalaShell; println(Seq(1).toDS().genSelect) */ - private def genSelect: String = { - (2 to 5).map { n => - val types = (1 to n).map(i =>s"U$i").mkString(", ") - val args = (1 to n).map(i => s"c$i: TypedColumn[U$i]").mkString(", ") - val encoders = (1 to n).map(i => s"c$i.encoder").mkString(", ") - val schema = (1 to n).map(i => s"""Alias(c$i.expr, "_$i")()""").mkString(" :: ") - s""" - |/** - | * Returns a new [[Dataset]] by computing the given [[Column]] expressions for each element. - | * @since 1.6.0 - | */ - |def select[$types]($args): Dataset[($types)] = { - | implicit val te = new Tuple${n}Encoder($encoders) - | new Dataset[($types)](sqlContext, - | Project( - | $schema :: Nil, - | logicalPlan)) - |} - | - """.stripMargin - }.mkString("\n") + /** + * Internal helper function for building typed selects that return tuples. For simplicity and + * code reuse, we do this without the help of the type system and then use helper functions + * that cast appropriately for the user facing interface. + */ + protected def selectUntyped(columns: TypedColumn[_]*): Dataset[_] = { + val aliases = columns.zipWithIndex.map { case (c, i) => Alias(c.expr, s"_${i + 1}")() } + val unresolvedPlan = Project(aliases, logicalPlan) + val execution = new QueryExecution(sqlContext, unresolvedPlan) + // Rebind the encoders to the nested schema that will be produced by the select. + val encoders = columns.map(_.encoder.asInstanceOf[ExpressionEncoder[_]]).zip(aliases).map { + case (e: ExpressionEncoder[_], a) if !e.flat => + e.nested(a.toAttribute).resolve(execution.analyzed.output) + case (e, a) => + e.unbind(a.toAttribute :: Nil).resolve(execution.analyzed.output) + } + new Dataset(sqlContext, execution, ExpressionEncoder.tuple(encoders)) } /** * Returns a new [[Dataset]] by computing the given [[Column]] expressions for each element. * @since 1.6.0 */ - def select[U1, U2](c1: TypedColumn[U1], c2: TypedColumn[U2]): Dataset[(U1, U2)] = { - implicit val te = new Tuple2Encoder(c1.encoder, c2.encoder) - new Dataset[(U1, U2)](sqlContext, - Project( - Alias(c1.expr, "_1")() :: Alias(c2.expr, "_2")() :: Nil, - logicalPlan)) - } - - + def select[U1, U2](c1: TypedColumn[U1], c2: TypedColumn[U2]): Dataset[(U1, U2)] = + selectUntyped(c1, c2).asInstanceOf[Dataset[(U1, U2)]] /** * Returns a new [[Dataset]] by computing the given [[Column]] expressions for each element. 
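
For example, mirroring the reworked "select 2" test below (a sketch assuming a SQLContext's implicits and `org.apache.spark.sql.functions.expr` are in scope):

    import org.apache.spark.sql.functions.expr

    val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDS()
    val pairs: Dataset[(String, Int)] =
      ds.select(expr("_1").as[String], expr("_2").as[Int])
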
* @since 1.6.0 */ - def select[U1, U2, U3](c1: TypedColumn[U1], c2: TypedColumn[U2], c3: TypedColumn[U3]): Dataset[(U1, U2, U3)] = { - implicit val te = new Tuple3Encoder(c1.encoder, c2.encoder, c3.encoder) - new Dataset[(U1, U2, U3)](sqlContext, - Project( - Alias(c1.expr, "_1")() :: Alias(c2.expr, "_2")() :: Alias(c3.expr, "_3")() :: Nil, - logicalPlan)) - } - - + def select[U1, U2, U3]( + c1: TypedColumn[U1], + c2: TypedColumn[U2], + c3: TypedColumn[U3]): Dataset[(U1, U2, U3)] = + selectUntyped(c1, c2, c3).asInstanceOf[Dataset[(U1, U2, U3)]] /** * Returns a new [[Dataset]] by computing the given [[Column]] expressions for each element. * @since 1.6.0 */ - def select[U1, U2, U3, U4](c1: TypedColumn[U1], c2: TypedColumn[U2], c3: TypedColumn[U3], c4: TypedColumn[U4]): Dataset[(U1, U2, U3, U4)] = { - implicit val te = new Tuple4Encoder(c1.encoder, c2.encoder, c3.encoder, c4.encoder) - new Dataset[(U1, U2, U3, U4)](sqlContext, - Project( - Alias(c1.expr, "_1")() :: Alias(c2.expr, "_2")() :: Alias(c3.expr, "_3")() :: Alias(c4.expr, "_4")() :: Nil, - logicalPlan)) - } - - + def select[U1, U2, U3, U4]( + c1: TypedColumn[U1], + c2: TypedColumn[U2], + c3: TypedColumn[U3], + c4: TypedColumn[U4]): Dataset[(U1, U2, U3, U4)] = + selectUntyped(c1, c2, c3, c4).asInstanceOf[Dataset[(U1, U2, U3, U4)]] /** * Returns a new [[Dataset]] by computing the given [[Column]] expressions for each element. * @since 1.6.0 */ - def select[U1, U2, U3, U4, U5](c1: TypedColumn[U1], c2: TypedColumn[U2], c3: TypedColumn[U3], c4: TypedColumn[U4], c5: TypedColumn[U5]): Dataset[(U1, U2, U3, U4, U5)] = { - implicit val te = new Tuple5Encoder(c1.encoder, c2.encoder, c3.encoder, c4.encoder, c5.encoder) - new Dataset[(U1, U2, U3, U4, U5)](sqlContext, - Project( - Alias(c1.expr, "_1")() :: Alias(c2.expr, "_2")() :: Alias(c3.expr, "_3")() :: Alias(c4.expr, "_4")() :: Alias(c5.expr, "_5")() :: Nil, - logicalPlan)) - } - - // scalastyle:on + def select[U1, U2, U3, U4, U5]( + c1: TypedColumn[U1], + c2: TypedColumn[U2], + c3: TypedColumn[U3], + c4: TypedColumn[U4], + c5: TypedColumn[U5]): Dataset[(U1, U2, U3, U4, U5)] = + selectUntyped(c1, c2, c3, c4, c5).asInstanceOf[Dataset[(U1, U2, U3, U4, U5)]] /* **************** * * Set operations * @@ -360,6 +357,48 @@ class Dataset[T] private[sql]( */ def subtract(other: Dataset[T]): Dataset[T] = withPlan[T](other)(Except) + /* ****** * + * Joins * + * ****** */ + + /** + * Joins this [[Dataset]] returning a [[Tuple2]] for each pair where `condition` evaluates to + * true. + * + * This is similar to the relation `join` function with one important difference in the + * result schema. Since `joinWith` preserves objects present on either side of the join, the + * result schema is similarly nested into a tuple under the column names `_1` and `_2`. + * + * This type of join can be useful both for preserving type-safety with the original object + * types as well as working with relational data where either side of the join has column + * names in common. 
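
As exercised in the new DatasetSuite tests, for example (a sketch assuming the SQLContext implicits are in scope):

    val ds1 = Seq(1, 2, 3).toDS().as("a")
    val ds2 = Seq(1, 2).toDS().as("b")

    // Inner join on the condition; each matching pair is preserved as a typed tuple.
    val joined: Dataset[(Int, Int)] = ds1.joinWith(ds2, $"a.value" === $"b.value")
    // joined.collect() => Array((1, 1), (2, 2))
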
+ */ + def joinWith[U](other: Dataset[U], condition: Column): Dataset[(T, U)] = { + val left = this.logicalPlan + val right = other.logicalPlan + + val leftData = this.encoder match { + case e if e.flat => Alias(left.output.head, "_1")() + case _ => Alias(CreateStruct(left.output), "_1")() + } + val rightData = other.encoder match { + case e if e.flat => Alias(right.output.head, "_2")() + case _ => Alias(CreateStruct(right.output), "_2")() + } + val leftEncoder = + if (encoder.flat) encoder else encoder.nested(leftData.toAttribute) + val rightEncoder = + if (other.encoder.flat) other.encoder else other.encoder.nested(rightData.toAttribute) + implicit val tuple2Encoder: Encoder[(T, U)] = + ExpressionEncoder.tuple(leftEncoder, rightEncoder) + + withPlan[(T, U)](other) { (left, right) => + Project( + leftData :: rightData :: Nil, + Join(left, right, Inner, Some(condition.expr))) + } + } + /* ************************** * * Gather to Driver Actions * * ************************** */ @@ -380,13 +419,10 @@ class Dataset[T] private[sql]( private[sql] def logicalPlan = queryExecution.analyzed private[sql] def withPlan(f: LogicalPlan => LogicalPlan): Dataset[T] = - new Dataset[T](sqlContext, sqlContext.executePlan(f(logicalPlan))) + new Dataset[T](sqlContext, sqlContext.executePlan(f(logicalPlan)), encoder) private[sql] def withPlan[R : Encoder]( other: Dataset[_])( f: (LogicalPlan, LogicalPlan) => LogicalPlan): Dataset[R] = - new Dataset[R]( - sqlContext, - sqlContext.executePlan( - f(logicalPlan, other.logicalPlan))) + new Dataset[R](sqlContext, f(logicalPlan, other.logicalPlan)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 5e7198f974389d3ec2c0554892a6d82eff099472..2cb94430e6178e57e0e7d22f668eaec599d7536d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -34,7 +34,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationEnd} import org.apache.spark.sql.SQLConf.SQLConfEntry import org.apache.spark.sql.catalyst.analysis._ -import org.apache.spark.sql.catalyst.encoders.Encoder +import org.apache.spark.sql.catalyst.encoders.{encoderFor, Encoder} import org.apache.spark.sql.catalyst.errors.DialectException import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.optimizer.{DefaultOptimizer, Optimizer} @@ -491,7 +491,7 @@ class SQLContext private[sql]( def createDataset[T : Encoder](data: Seq[T]): Dataset[T] = { - val enc = implicitly[Encoder[T]] + val enc = encoderFor[T] val attributes = enc.schema.toAttributes val encoded = data.map(d => enc.toRow(d).copy()) val plan = new LocalRelation(attributes, encoded) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala index af8474df0de808b1ea6ffd23adf53cb70c7b3432..f460a86414c4159f3506dd0c6a6bb60db36a5b37 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala @@ -37,11 +37,16 @@ import org.apache.spark.unsafe.types.UTF8String abstract class SQLImplicits { protected def _sqlContext: SQLContext - implicit def newProductEncoder[T <: Product : TypeTag]: Encoder[T] = ProductEncoder[T] + implicit def newProductEncoder[T <: Product : TypeTag]: Encoder[T] = ExpressionEncoder[T]() - implicit def newIntEncoder: 
Encoder[Int] = new IntEncoder() - implicit def newLongEncoder: Encoder[Long] = new LongEncoder() - implicit def newStringEncoder: Encoder[String] = new StringEncoder() + implicit def newIntEncoder: Encoder[Int] = ExpressionEncoder[Int](flat = true) + implicit def newLongEncoder: Encoder[Long] = ExpressionEncoder[Long](flat = true) + implicit def newDoubleEncoder: Encoder[Double] = ExpressionEncoder[Double](flat = true) + implicit def newFloatEncoder: Encoder[Float] = ExpressionEncoder[Float](flat = true) + implicit def newByteEncoder: Encoder[Byte] = ExpressionEncoder[Byte](flat = true) + implicit def newShortEncoder: Encoder[Short] = ExpressionEncoder[Short](flat = true) + implicit def newBooleanEncoder: Encoder[Boolean] = ExpressionEncoder[Boolean](flat = true) + implicit def newStringEncoder: Encoder[String] = ExpressionEncoder[String](flat = true) implicit def localSeqToDatasetHolder[T : Encoder](s: Seq[T]): DatasetHolder[T] = { DatasetHolder(_sqlContext.createDataset(s)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index 2bb3dba5bd2baf6f7c58d132a714590c67f32fe4..89938471ee381aea994b7491688a8728ea46b26a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution import org.apache.spark.rdd.{PartitionwiseSampledRDD, RDD, ShuffledRDD} import org.apache.spark.shuffle.sort.SortShuffleManager import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.encoders.Encoder +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeRowJoiner import org.apache.spark.sql.catalyst.plans.physical._ @@ -319,8 +319,8 @@ case class OutputFaker(output: Seq[Attribute], child: SparkPlan) extends SparkPl */ case class MapPartitions[T, U]( func: Iterator[T] => Iterator[U], - tEncoder: Encoder[T], - uEncoder: Encoder[U], + tEncoder: ExpressionEncoder[T], + uEncoder: ExpressionEncoder[U], output: Seq[Attribute], child: SparkPlan) extends UnaryNode { @@ -337,8 +337,8 @@ case class MapPartitions[T, U]( */ case class AppendColumns[T, U]( func: T => U, - tEncoder: Encoder[T], - uEncoder: Encoder[U], + tEncoder: ExpressionEncoder[T], + uEncoder: ExpressionEncoder[U], newColumns: Seq[Attribute], child: SparkPlan) extends UnaryNode { @@ -363,9 +363,9 @@ case class AppendColumns[T, U]( */ case class MapGroups[K, T, U]( func: (K, Iterator[T]) => Iterator[U], - kEncoder: Encoder[K], - tEncoder: Encoder[T], - uEncoder: Encoder[U], + kEncoder: ExpressionEncoder[K], + tEncoder: ExpressionEncoder[T], + uEncoder: ExpressionEncoder[U], groupingAttributes: Seq[Attribute], output: Seq[Attribute], child: SparkPlan) extends UnaryNode { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 08496249c60cc0e55c07de1926e74caf796c8a50..aebb390a1d15de736e5903ba619929d619a6efa6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -34,6 +34,13 @@ class DatasetSuite extends QueryTest with SharedSQLContext { data: _*) } + test("as tuple") { + val data = Seq(("a", 1), ("b", 2)).toDF("a", "b") + checkAnswer( + 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 08496249c60cc0e55c07de1926e74caf796c8a50..aebb390a1d15de736e5903ba619929d619a6efa6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -34,6 +34,13 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
       data: _*)
   }
 
+  test("as tuple") {
+    val data = Seq(("a", 1), ("b", 2)).toDF("a", "b")
+    checkAnswer(
+      data.as[(String, Int)],
+      ("a", 1), ("b", 2))
+  }
+
   test("as case class / collect") {
     val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDF("a", "b").as[ClassData]
     checkAnswer(
@@ -61,14 +68,40 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
       2, 3, 4)
   }
 
-  test("select 3") {
+  test("select 2") {
     val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDS()
     checkAnswer(
       ds.select(
         expr("_1").as[String],
-        expr("_2").as[Int],
-        expr("_2 + 1").as[Int]),
-      ("a", 1, 2), ("b", 2, 3), ("c", 3, 4))
+        expr("_2").as[Int]) : Dataset[(String, Int)],
+      ("a", 1), ("b", 2), ("c", 3))
+  }
+
+  test("select 2, primitive and tuple") {
+    val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDS()
+    checkAnswer(
+      ds.select(
+        expr("_1").as[String],
+        expr("struct(_2, _2)").as[(Int, Int)]),
+      ("a", (1, 1)), ("b", (2, 2)), ("c", (3, 3)))
+  }
+
+  test("select 2, primitive and class") {
+    val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDS()
+    checkAnswer(
+      ds.select(
+        expr("_1").as[String],
+        expr("named_struct('a', _1, 'b', _2)").as[ClassData]),
+      ("a", ClassData("a", 1)), ("b", ClassData("b", 2)), ("c", ClassData("c", 3)))
+  }
+
+  test("select 2, primitive and class, fields reordered") {
+    val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDS()
+    checkDecoding(
+      ds.select(
+        expr("_1").as[String],
+        expr("named_struct('b', _2, 'a', _1)").as[ClassData]),
+      ("a", ClassData("a", 1)), ("b", ClassData("b", 2)), ("c", ClassData("c", 3)))
   }
 
   test("filter") {
@@ -102,6 +135,54 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
     assert(ds.fold(("", 0))((a, b) => ("sum", a._2 + b._2)) == ("sum", 6))
   }
 
+  test("joinWith, flat schema") {
+    val ds1 = Seq(1, 2, 3).toDS().as("a")
+    val ds2 = Seq(1, 2).toDS().as("b")
+
+    checkAnswer(
+      ds1.joinWith(ds2, $"a.value" === $"b.value"),
+      (1, 1), (2, 2))
+  }
+
+  test("joinWith, expression condition") {
+    val ds1 = Seq(ClassData("a", 1), ClassData("b", 2)).toDS()
+    val ds2 = Seq(("a", 1), ("b", 2)).toDS()
+
+    checkAnswer(
+      ds1.joinWith(ds2, $"_1" === $"a"),
+      (ClassData("a", 1), ("a", 1)), (ClassData("b", 2), ("b", 2)))
+  }
+
+  test("joinWith tuple with primitive, expression") {
+    val ds1 = Seq(1, 1, 2).toDS()
+    val ds2 = Seq(("a", 1), ("b", 2)).toDS()
+
+    checkAnswer(
+      ds1.joinWith(ds2, $"value" === $"_2"),
+      (1, ("a", 1)), (1, ("a", 1)), (2, ("b", 2)))
+  }
+
+  test("joinWith class with primitive, toDF") {
+    val ds1 = Seq(1, 1, 2).toDS()
+    val ds2 = Seq(ClassData("a", 1), ClassData("b", 2)).toDS()
+
+    checkAnswer(
+      ds1.joinWith(ds2, $"value" === $"b").toDF().select($"_1", $"_2.a", $"_2.b"),
+      Row(1, "a", 1) :: Row(1, "a", 1) :: Row(2, "b", 2) :: Nil)
+  }
+
+  test("multi-level joinWith") {
+    val ds1 = Seq(("a", 1), ("b", 2)).toDS().as("a")
+    val ds2 = Seq(("a", 1), ("b", 2)).toDS().as("b")
+    val ds3 = Seq(("a", 1), ("b", 2)).toDS().as("c")
+
+    checkAnswer(
+      ds1.joinWith(ds2, $"a._2" === $"b._2").as("ab").joinWith(ds3, $"ab._1._2" === $"c._2"),
+      ((("a", 1), ("a", 1)), ("a", 1)),
+      ((("b", 2), ("b", 2)), ("b", 2)))
+
+  }
+
   test("groupBy function, keys") {
     val ds = Seq(("a", 1), ("b", 1)).toDS()
     val grouped = ds.groupBy(v => (1, v._2))
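(Reviewer aside, not part of the patch.) Two behaviours the tests above pin down: tuple fields are decoded by ordinal while case-class fields are decoded by name (per the ScalaReflection change at the top of the patch), and checkAnswer also serializes the expected answer with its own encoder, which is why the field-reordering case goes through the weaker checkDecoding added below. A small sketch of the ordinal-versus-name distinction, under the same assumptions as the earlier sketches:

    val df = Seq(("a", 1), ("b", 2)).toDF("a", "b")

    // Tuples bind fields by position, so the column names are irrelevant here.
    val asTuple: Dataset[(String, Int)] = df.as[(String, Int)]

    // Case classes bind fields by name, so columns "a" and "b" must match ClassData's fields.
    val asClass: Dataset[ClassData] = df.as[ClassData]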
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
index aba567512fe328d45dd1f721e20846c8eaf91a94..73e02eb0d957481479db119f5042fcd0388eb413 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
@@ -20,12 +20,11 @@ package org.apache.spark.sql
 
 import java.util.{Locale, TimeZone}
 
 import scala.collection.JavaConverters._
-import scala.reflect.runtime.universe._
 
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.util._
 import org.apache.spark.sql.columnar.InMemoryRelation
-import org.apache.spark.sql.catalyst.encoders.{ProductEncoder, Encoder}
+import org.apache.spark.sql.catalyst.encoders.Encoder
 
 abstract class QueryTest extends PlanTest {
@@ -55,10 +54,49 @@ abstract class QueryTest extends PlanTest {
     }
   }
 
-  protected def checkAnswer[T : Encoder](ds: => Dataset[T], expectedAnswer: T*): Unit = {
+  /**
+   * Evaluates a dataset to make sure that the result of calling collect matches the given
+   * expected answer.
+   *  - Special handling is done based on whether the query plan should be expected to return
+   *    the results in sorted order.
+   *  - This function also checks to make sure that the schema for serializing the expected answer
+   *    matches that produced by the dataset (i.e. that manual construction of objects matches
+   *    the constructed encoder for cases like joins, etc). Note that this means that it will fail
+   *    for cases where reordering is done on fields. For such tests, use `checkDecoding` instead,
+   *    which performs a subset of the checks done by this function.
+   */
+  protected def checkAnswer[T : Encoder](
+      ds: => Dataset[T],
+      expectedAnswer: T*): Unit = {
     checkAnswer(
       ds.toDF(),
       sqlContext.createDataset(expectedAnswer).toDF().collect().toSeq)
+
+    checkDecoding(ds, expectedAnswer: _*)
+  }
+
+  protected def checkDecoding[T](
+      ds: => Dataset[T],
+      expectedAnswer: T*): Unit = {
+    val decoded = try ds.collect().toSet catch {
+      case e: Exception =>
+        fail(
+          s"""
+             |Exception collecting dataset as objects
+             |${ds.encoder}
+             |${ds.encoder.constructExpression.treeString}
+             |${ds.queryExecution}
+           """.stripMargin, e)
+    }
+
+    if (decoded != expectedAnswer.toSet) {
+      fail(
+        s"""Decoded objects do not match expected objects:
+           |Expected: ${expectedAnswer.toSet.toSeq.map((a: Any) => a.toString).sorted}
+           |Actual ${decoded.toSet.toSeq.map((a: Any) => a.toString).sorted}
+           |${ds.encoder.constructExpression.treeString}
+         """.stripMargin)
+    }
   }
 
   /**