Skip to content
Snippets Groups Projects
Commit e0b20f9f authored by Liwei Lin's avatar Liwei Lin Committed by Herman van Hovell
Browse files

[SPARK-17061][SPARK-17093][SQL] `MapObjects` should make copies of unsafe-backed data

## What changes were proposed in this pull request?

Currently `MapObjects` does not make copies of unsafe-backed data, leading to problems like [SPARK-17061](https://issues.apache.org/jira/browse/SPARK-17061) and [SPARK-17093](https://issues.apache.org/jira/browse/SPARK-17093).

This patch makes `MapObjects` make copies of unsafe-backed data.

Generated code - prior to this patch:
```java
...
/* 295 */ if (isNull12) {
/* 296 */   convertedArray1[loopIndex1] = null;
/* 297 */ } else {
/* 298 */   convertedArray1[loopIndex1] = value12;
/* 299 */ }
...
```

Generated code - after this patch:
```java
...
/* 295 */ if (isNull12) {
/* 296 */   convertedArray1[loopIndex1] = null;
/* 297 */ } else {
/* 298 */   convertedArray1[loopIndex1] = value12 instanceof UnsafeRow? value12.copy() : value12;
/* 299 */ }
...
```

## How was this patch tested?

Added a new test case that fails without this patch.

Author: Liwei Lin <lwlin7@gmail.com>

Closes #14698 from lw-lin/mapobjects-copy.
parent 2bcd5d5c
No related branches found
No related tags found
No related merge requests found
......@@ -494,6 +494,16 @@ case class MapObjects private(
s"$seq == null ? $array[$loopIndex] : $seq.apply($loopIndex)"
}
// Builds a Java conditional expression that yields `value.copy()` when `value` is an
// instance of `clazz` (i.e. the data is unsafe-backed) and `value` itself otherwise.
def makeCopyIfInstanceOf(clazz: Class[_ <: Any], value: String) = {
val simpleName = clazz.getSimpleName
Seq(value, " instanceof ", simpleName, "? ", value, ".copy() : ", value).mkString
}
// Choose the expression that is stored into the converted array. When the lambda's result
// type is a struct, array or map, wrap the generated value in an instanceof-guarded copy()
// against the matching Unsafe* class; for every other type use the generated value as-is.
val genFunctionValue = lambdaFunction.dataType match {
// Struct results: copy when the value is an UnsafeRow.
case StructType(_) => makeCopyIfInstanceOf(classOf[UnsafeRow], genFunction.value)
// Array results: copy when the value is an UnsafeArrayData.
case ArrayType(_, _) => makeCopyIfInstanceOf(classOf[UnsafeArrayData], genFunction.value)
// Map results: copy when the value is an UnsafeMapData.
case MapType(_, _, _) => makeCopyIfInstanceOf(classOf[UnsafeMapData], genFunction.value)
// Any other result type: no copy is generated.
case _ => genFunction.value
}
val loopNullCheck = inputDataType match {
case _: ArrayType => s"$loopIsNull = ${genInputData.value}.isNullAt($loopIndex);"
// The element of primitive array will never be null.
......@@ -521,7 +531,7 @@ case class MapObjects private(
if (${genFunction.isNull}) {
$convertedArray[$loopIndex] = null;
} else {
$convertedArray[$loopIndex] = ${genFunction.value};
$convertedArray[$loopIndex] = $genFunctionValue;
}
$loopIndex += 1;
......
......@@ -136,7 +136,7 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks {
// some expression is reusing variable names across different instances.
// This behavior is tested in ExpressionEvalHelperSuite.
val plan = generateProject(
GenerateUnsafeProjection.generate(
UnsafeProjection.create(
Alias(expression, s"Optimized($expression)1")() ::
Alias(expression, s"Optimized($expression)2")() :: Nil),
expression)
......
......@@ -19,7 +19,9 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions.objects.Invoke
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData}
import org.apache.spark.sql.types.{IntegerType, ObjectType}
......@@ -32,4 +34,36 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
val invoke = Invoke(inputObject, "_2", IntegerType)
checkEvaluationWithGeneratedMutableProjection(invoke, null, inputRow)
}
// Evaluates the first serializer expression of three array encoders (array of tuples, array of
// arrays, array of maps) through checkEvalutionWithUnsafeProjection and compares each result
// with a hand-built GenericArrayData.
// NOTE(review): presumably each of these serializers contains a MapObjects expression, as the
// test name implies - confirm against ExpressionEncoder.
test("MapObjects should make copies of unsafe-backed data") {
// Case 1 (UnsafeRow-backed data): an array of (Integer, Integer) tuples; each element is
// expected as a two-field row.
val structEncoder = ExpressionEncoder[Array[Tuple2[java.lang.Integer, java.lang.Integer]]]
val structInputRow = InternalRow.fromSeq(Seq(Array((1, 2), (3, 4))))
val structExpected = new GenericArrayData(
Array(InternalRow.fromSeq(Seq(1, 2)), InternalRow.fromSeq(Seq(3, 4))))
checkEvalutionWithUnsafeProjection(
structEncoder.serializer.head, structExpected, structInputRow)
// Case 2 (UnsafeArray-backed data): an array of Int arrays; each element is expected as a
// nested GenericArrayData.
val arrayEncoder = ExpressionEncoder[Array[Array[Int]]]
val arrayInputRow = InternalRow.fromSeq(Seq(Array(Array(1, 2), Array(3, 4))))
val arrayExpected = new GenericArrayData(
Array(new GenericArrayData(Array(1, 2)), new GenericArrayData(Array(3, 4))))
checkEvalutionWithUnsafeProjection(
arrayEncoder.serializer.head, arrayExpected, arrayInputRow)
// Case 3 (UnsafeMap-backed data): an array of Int -> Int maps; each element is expected as an
// ArrayBasedMapData built from a keys array and a values array.
val mapEncoder = ExpressionEncoder[Array[Map[Int, Int]]]
val mapInputRow = InternalRow.fromSeq(Seq(Array(
Map(1 -> 100, 2 -> 200), Map(3 -> 300, 4 -> 400))))
val mapExpected = new GenericArrayData(Seq(
new ArrayBasedMapData(
new GenericArrayData(Array(1, 2)),
new GenericArrayData(Array(100, 200))),
new ArrayBasedMapData(
new GenericArrayData(Array(3, 4)),
new GenericArrayData(Array(300, 400)))))
checkEvalutionWithUnsafeProjection(
mapEncoder.serializer.head, mapExpected, mapInputRow)
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment