diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 9b6b5b8bd1a276fb2445e7008e9e77e8bde0593e..9013fd050b5f99535b6f53622de0c9026254b053 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -414,10 +414,6 @@ object ScalaReflection extends ScalaReflection { } else { val clsName = getClassNameFromType(elementType) val newPath = s"""- array element class: "$clsName"""" +: walkedTypePath - // `MapObjects` will run `extractorFor` lazily, we need to eagerly call `extractorFor` here - // to trigger the type check. - extractorFor(inputObject, elementType, newPath) - MapObjects(extractorFor(_, elementType, newPath), input, externalDataType) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala index 67518f52d4a58db01bef4c17de1726d80560a8af..d34ec9408ae1bd959c75d9e3cb5dc8820b191914 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala @@ -193,7 +193,7 @@ object RowEncoder { case ArrayType(et, nullable) => val arrayData = Invoke( - MapObjects(constructorFor, input, et), + MapObjects(constructorFor(_), input, et), "array", ObjectType(classOf[Array[_]])) StaticInvoke( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala index e6ab9a31be59edee7774633426d5a9123ab7917c..b2facfda24446849985a3ab10408e4698173a004 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala @@ -326,19 +326,28 @@ case class WrapOption(child: Expression) * A place holder for the loop variable used in [[MapObjects]]. This should never be constructed * manually, but will instead be passed into the provided lambda function. */ -case class LambdaVariable(value: String, isNull: String, dataType: DataType) extends Expression { +case class LambdaVariable(value: String, isNull: String, dataType: DataType) extends LeafExpression + with Unevaluable { - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = - throw new UnsupportedOperationException("Only calling gen() is supported.") + override def nullable: Boolean = true - override def children: Seq[Expression] = Nil - override def gen(ctx: CodeGenContext): GeneratedExpressionCode = + override def gen(ctx: CodeGenContext): GeneratedExpressionCode = { GeneratedExpressionCode(code = "", value = value, isNull = isNull) + } +} - override def nullable: Boolean = false - override def eval(input: InternalRow): Any = - throw new UnsupportedOperationException("Only code-generated evaluation is supported.") +object MapObjects { + private val curId = new java.util.concurrent.atomic.AtomicInteger() + def apply( + function: Expression => Expression, + inputData: Expression, + elementType: DataType): MapObjects = { + val loopValue = "MapObjects_loopValue" + curId.getAndIncrement() + val loopIsNull = "MapObjects_loopIsNull" + curId.getAndIncrement() + val loopVar = LambdaVariable(loopValue, loopIsNull, elementType) + MapObjects(loopVar, function(loopVar), inputData) + } } /** @@ -349,20 +358,16 @@ case class LambdaVariable(value: String, isNull: String, dataType: DataType) ext * The following collection ObjectTypes are currently supported: * Seq, Array, ArrayData, java.util.List * - * @param function A function that returns an expression, given an attribute that can be used - * to access the current value. This is does as a lambda function so that - * a unique attribute reference can be provided for each expression (thus allowing - * us to nest multiple MapObject calls). + * @param loopVar A place holder that used as the loop variable when iterate the collection, and + * used as input for the `lambdaFunction`. It also carries the element type info. + * @param lambdaFunction A function that take the `loopVar` as input, and used as lambda function + * to handle collection elements. * @param inputData An expression that when evaluted returns a collection object. - * @param elementType The type of element in the collection, expressed as a DataType. */ case class MapObjects( - function: AttributeReference => Expression, - inputData: Expression, - elementType: DataType) extends Expression { - - private lazy val loopAttribute = AttributeReference("loopVar", elementType)() - private lazy val completeFunction = function(loopAttribute) + loopVar: LambdaVariable, + lambdaFunction: Expression, + inputData: Expression) extends Expression { private def itemAccessorMethod(dataType: DataType): String => String = dataType match { case NullType => @@ -402,37 +407,23 @@ case class MapObjects( override def nullable: Boolean = true - override def children: Seq[Expression] = completeFunction :: inputData :: Nil + override def children: Seq[Expression] = lambdaFunction :: inputData :: Nil override def eval(input: InternalRow): Any = throw new UnsupportedOperationException("Only code-generated evaluation is supported") - override def dataType: DataType = ArrayType(completeFunction.dataType) + override def dataType: DataType = ArrayType(lambdaFunction.dataType) override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { val javaType = ctx.javaType(dataType) - val elementJavaType = ctx.javaType(elementType) + val elementJavaType = ctx.javaType(loopVar.dataType) val genInputData = inputData.gen(ctx) - - // Variables to hold the element that is currently being processed. - val loopValue = ctx.freshName("loopValue") - val loopIsNull = ctx.freshName("loopIsNull") - - val loopVariable = LambdaVariable(loopValue, loopIsNull, elementType) - val substitutedFunction = completeFunction transform { - case a: AttributeReference if a == loopAttribute => loopVariable - } - // A hack to run this through the analyzer (to bind extractions). - val boundFunction = - SimpleAnalyzer.execute(Project(Alias(substitutedFunction, "")() :: Nil, LocalRelation(Nil))) - .expressions.head.children.head - - val genFunction = boundFunction.gen(ctx) + val genFunction = lambdaFunction.gen(ctx) val dataLength = ctx.freshName("dataLength") val convertedArray = ctx.freshName("convertedArray") val loopIndex = ctx.freshName("loopIndex") - val convertedType = ctx.boxedType(boundFunction.dataType) + val convertedType = ctx.boxedType(lambdaFunction.dataType) // Because of the way Java defines nested arrays, we have to handle the syntax specially. // Specifically, we have to insert the [$dataLength] in between the type and any extra nested @@ -446,9 +437,9 @@ case class MapObjects( } val loopNullCheck = if (primitiveElement) { - s"boolean $loopIsNull = ${genInputData.value}.isNullAt($loopIndex);" + s"boolean ${loopVar.isNull} = ${genInputData.value}.isNullAt($loopIndex);" } else { - s"boolean $loopIsNull = ${genInputData.isNull} || $loopValue == null;" + s"boolean ${loopVar.isNull} = ${genInputData.isNull} || ${loopVar.value} == null;" } s""" @@ -464,11 +455,11 @@ case class MapObjects( int $loopIndex = 0; while ($loopIndex < $dataLength) { - $elementJavaType $loopValue = + $elementJavaType ${loopVar.value} = ($elementJavaType)${genInputData.value}${itemAccessor(loopIndex)}; $loopNullCheck - if ($loopIsNull) { + if (${loopVar.isNull}) { $convertedArray[$loopIndex] = null; } else { ${genFunction.code} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala index d6ca138672ef194ab5cd9d81437004e6938d6b6e..7233e0f1b5baf0eec9ee6b13ebb60557c8be4979 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala @@ -145,6 +145,7 @@ class ExpressionEncoderSuite extends SparkFunSuite { case class InnerClass(i: Int) productTest(InnerClass(1)) + encodeDecodeTest(Array(InnerClass(1)), "array of inner class") productTest(PrimitiveData(1, 1, 1, 1, 1, 1, true))