diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ClassEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ClassEncoder.scala
index f3a10638717756d05a02167db718c4e99c32b6d9..54096f18cbea1d332c352c1298f050ce4c92b613 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ClassEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ClassEncoder.scala
@@ -48,20 +48,12 @@ case class ClassEncoder[T](
   private val dataType = ObjectType(clsTag.runtimeClass)
 
   override def toRow(t: T): InternalRow = {
-    if (t == null) {
-      null
-    } else {
-      inputRow(0) = t
-      extractProjection(inputRow)
-    }
+    inputRow(0) = t
+    extractProjection(inputRow)
   }
 
   override def fromRow(row: InternalRow): T = {
-    if (row eq null) {
-      null.asInstanceOf[T]
-    } else {
-      constructProjection(row).get(0, dataType).asInstanceOf[T]
-    }
+    constructProjection(row).get(0, dataType).asInstanceOf[T]
   }
 
   override def bind(schema: Seq[Attribute]): ClassEncoder[T] = {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
index 3e74aabd078df682492978dfec28c56bc553e0bb..5142856afdcacb7c458fadb1f3062cab416e86f2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
@@ -26,8 +26,11 @@ import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
 
+/**
+ * A factory for constructing encoders that convert external rows to/from the Spark SQL
+ * internal binary representation.
+ */
 object RowEncoder {
-
   def apply(schema: StructType): ClassEncoder[Row] = {
     val cls = classOf[Row]
     val inputObject = BoundReference(0, ObjectType(cls), nullable = true)
@@ -136,7 +139,7 @@ object RowEncoder {
         constructorFor(BoundReference(i, f.dataType, f.nullable), f.dataType)
       )
     }
-    CreateRow(fields)
+    CreateExternalRow(fields)
   }
 
   private def constructorFor(input: Expression, dataType: DataType): Expression = dataType match {
@@ -195,7 +198,7 @@ object RowEncoder {
           Literal.create(null, externalDataTypeFor(f.dataType)),
           constructorFor(getField(input, i, f.dataType), f.dataType))
       }
-      CreateRow(convertedFields)
+      CreateExternalRow(convertedFields)
   }
 
   private def getField(
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
index 8fc00ad1bcb047bb349597894e5a672c1d28c57a..b42d6c5c1e14e1aa5f44a1f1edbd97a55f36cba1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
@@ -456,7 +456,13 @@ case class MapObjects(
   }
 }
 
-case class CreateRow(children: Seq[Expression]) extends Expression {
+/**
+ * Constructs a new external row, using the result of evaluating the specified expressions
+ * as content.
+ *
+ * @param children A list of expressions to use as content of the external row.
+ */
+case class CreateExternalRow(children: Seq[Expression]) extends Expression {
   override def dataType: DataType = ObjectType(classOf[Row])
 
   override def nullable: Boolean = false
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
index 6041b62b74bddc2b15f2fec72605b8328df02f14..e8301e8e06b520a63df360f16ed6abd1358af772 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
@@ -73,7 +73,7 @@ class RowEncoderSuite extends SparkFunSuite {
   private def encodeDecodeTest(schema: StructType): Unit = {
     test(s"encode/decode: ${schema.simpleString}") {
       val encoder = RowEncoder(schema)
-      val inputGenerator = RandomDataGenerator.forType(schema).get
+      val inputGenerator = RandomDataGenerator.forType(schema, nullable = false).get
 
       var input: Row = null
       try {
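
For reviewers who want to see the renamed pieces in context, here is a minimal round-trip sketch (not part of the patch). It assumes only what the diff itself shows for the 1.6-era API: RowEncoder(schema) returns a ClassEncoder[Row] that exposes toRow/fromRow, and the deserializing path now builds its result through CreateExternalRow. The object name RowEncoderRoundTrip and the sample schema are invented for illustration; depending on the surrounding code path, the encoder may also need to be bound to concrete attributes via bind(...) before fromRow is usable.

import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

// Illustrative sketch only; names below are hypothetical and not part of this change.
object RowEncoderRoundTrip {
  def main(args: Array[String]): Unit = {
    val schema = StructType(Seq(
      StructField("name", StringType),
      StructField("age", IntegerType)))

    // Build the extract/construct projections for this schema.
    val encoder = RowEncoder(schema)

    // toRow converts an external Row into the internal representation.
    // With this patch the encoder no longer special-cases a null top-level
    // input, so callers are expected to pass a non-null row.
    val internalRow = encoder.toRow(Row("Alice", 30))

    // fromRow converts the internal representation back into an external Row,
    // which is now constructed via the CreateExternalRow expression.
    val externalRow: Row = encoder.fromRow(internalRow)
    println(externalRow)
  }
}

The test change follows the same reasoning: because the encoder no longer handles null top-level input, RowEncoderSuite asks RandomDataGenerator for non-nullable values when generating rows to round-trip.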