From 4eb41879ce774dec1d16b2281ab1fbf41f9d418a Mon Sep 17 00:00:00 2001 From: Wenchen Fan <wenchen@databricks.com> Date: Sat, 1 Jul 2017 09:25:29 +0800 Subject: [PATCH] [SPARK-17528][SQL] data should be copied properly before saving into InternalRow ## What changes were proposed in this pull request? For performance reasons, `UnsafeRow.getString`, `getStruct`, etc. return a "pointer" that points to a memory region of this unsafe row. This makes the unsafe projection a little dangerous, because all of its output rows share one instance. When we implement SQL operators, we should be careful to not cache the input rows because they may be produced by unsafe projection from child operator and thus its content may change overtime. However, when we updating values of InternalRow(e.g. in mutable projection and safe projection), we only copy UTF8String, we should also copy InternalRow, ArrayData and MapData. This PR fixes this, and also fixes the copy of vairous InternalRow, ArrayData and MapData implementations. ## How was this patch tested? new regression tests Author: Wenchen Fan <wenchen@databricks.com> Closes #18483 from cloud-fan/fix-copy. --- .../apache/spark/unsafe/types/UTF8String.java | 6 + .../spark/sql/catalyst/InternalRow.scala | 27 ++++- .../spark/sql/catalyst/expressions/Cast.scala | 2 +- .../expressions/SpecificInternalRow.scala | 12 -- .../expressions/aggregate/collect.scala | 2 +- .../expressions/aggregate/interfaces.scala | 6 + .../expressions/codegen/CodeGenerator.scala | 6 +- .../codegen/GenerateSafeProjection.scala | 2 - .../spark/sql/catalyst/expressions/rows.scala | 23 ++-- .../sql/catalyst/util/GenericArrayData.scala | 10 +- .../scala/org/apache/spark/sql/RowTest.scala | 4 - .../catalyst/expressions/MapDataSuite.scala | 57 ---------- .../codegen/GeneratedProjectionSuite.scala | 36 ++++++ .../sql/catalyst/util/ComplexDataSuite.scala | 107 ++++++++++++++++++ .../execution/vectorized/ColumnarBatch.java | 2 +- .../SortBasedAggregationIterator.scala | 15 +-- .../columnar/GenerateColumnAccessor.scala | 1 - .../execution/window/AggregateProcessor.scala | 7 +- 18 files changed, 212 insertions(+), 113 deletions(-) delete mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MapDataSuite.scala create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ComplexDataSuite.scala diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index 40b9fc9534..9de4ca71ff 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -1088,6 +1088,12 @@ public final class UTF8String implements Comparable<UTF8String>, Externalizable, return fromBytes(getBytes()); } + public UTF8String copy() { + byte[] bytes = new byte[numBytes]; + copyMemory(base, offset, bytes, BYTE_ARRAY_OFFSET, numBytes); + return fromBytes(bytes); + } + @Override public int compareTo(@Nonnull final UTF8String other) { int len = Math.min(numBytes, other.numBytes); diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala index 256f64e320..29110640d6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala @@ -18,7 +18,9 @@ package org.apache.spark.sql.catalyst import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} import org.apache.spark.sql.types.{DataType, Decimal, StructType} +import org.apache.spark.unsafe.types.UTF8String /** * An abstract class for row used internally in Spark SQL, which only contains the columns as @@ -33,6 +35,10 @@ abstract class InternalRow extends SpecializedGetters with Serializable { def setNullAt(i: Int): Unit + /** + * Updates the value at column `i`. Note that after updating, the given value will be kept in this + * row, and the caller side should guarantee that this value won't be changed afterwards. + */ def update(i: Int, value: Any): Unit // default implementation (slow) @@ -58,7 +64,15 @@ abstract class InternalRow extends SpecializedGetters with Serializable { def copy(): InternalRow /** Returns true if there are any NULL values in this row. */ - def anyNull: Boolean + def anyNull: Boolean = { + val len = numFields + var i = 0 + while (i < len) { + if (isNullAt(i)) { return true } + i += 1 + } + false + } /* ---------------------- utility methods for Scala ---------------------- */ @@ -94,4 +108,15 @@ object InternalRow { /** Returns an empty [[InternalRow]]. */ val empty = apply() + + /** + * Copies the given value if it's string/struct/array/map type. + */ + def copyValue(value: Any): Any = value match { + case v: UTF8String => v.copy() + case v: InternalRow => v.copy() + case v: ArrayData => v.copy() + case v: MapData => v.copy() + case _ => value + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 43df19ba00..3862e64b9d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -1047,7 +1047,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String final $rowClass $result = new $rowClass(${fieldsCasts.length}); final InternalRow $tmpRow = $c; $fieldsEvalCode - $evPrim = $result.copy(); + $evPrim = $result; """ } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificInternalRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificInternalRow.scala index 74e0b4691d..75feaf670c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificInternalRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificInternalRow.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.types._ /** @@ -220,17 +219,6 @@ final class SpecificInternalRow(val values: Array[MutableValue]) extends BaseGen override def isNullAt(i: Int): Boolean = values(i).isNull - override def copy(): InternalRow = { - val newValues = new Array[Any](values.length) - var i = 0 - while (i < values.length) { - newValues(i) = values(i).boxed - i += 1 - } - - new GenericInternalRow(newValues) - } - override protected def genericGet(i: Int): Any = values(i).boxed override def update(ordinal: Int, value: Any) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala index 26cd9ab665..0d2f9889a2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala @@ -52,7 +52,7 @@ abstract class Collect[T <: Growable[Any] with Iterable[Any]] extends TypedImper // Do not allow null values. We follow the semantics of Hive's collect_list/collect_set here. // See: org.apache.hadoop.hive.ql.udf.generic.GenericUDAFMkCollectionEvaluator if (value != null) { - buffer += value + buffer += InternalRow.copyValue(value) } buffer } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala index fffcc7c9ef..7af4901435 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala @@ -317,6 +317,9 @@ abstract class ImperativeAggregate extends AggregateFunction with CodegenFallbac * Updates its aggregation buffer, located in `mutableAggBuffer`, based on the given `inputRow`. * * Use `fieldNumber + mutableAggBufferOffset` to access fields of `mutableAggBuffer`. + * + * Note that, the input row may be produced by unsafe projection and it may not be safe to cache + * some fields of the input row, as the values can be changed unexpectedly. */ def update(mutableAggBuffer: InternalRow, inputRow: InternalRow): Unit @@ -326,6 +329,9 @@ abstract class ImperativeAggregate extends AggregateFunction with CodegenFallbac * * Use `fieldNumber + mutableAggBufferOffset` to access fields of `mutableAggBuffer`. * Use `fieldNumber + inputAggBufferOffset` to access fields of `inputAggBuffer`. + * + * Note that, the input row may be produced by unsafe projection and it may not be safe to cache + * some fields of the input row, as the values can be changed unexpectedly. */ def merge(mutableAggBuffer: InternalRow, inputAggBuffer: InternalRow): Unit } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 5158949b95..b15bf2ca7c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -408,9 +408,11 @@ class CodegenContext { dataType match { case _ if isPrimitiveType(jt) => s"$row.set${primitiveTypeName(jt)}($ordinal, $value)" case t: DecimalType => s"$row.setDecimal($ordinal, $value, ${t.precision})" - // The UTF8String may came from UnsafeRow, otherwise clone is cheap (re-use the bytes) - case StringType => s"$row.update($ordinal, $value.clone())" case udt: UserDefinedType[_] => setColumn(row, udt.sqlType, ordinal, value) + // The UTF8String, InternalRow, ArrayData and MapData may came from UnsafeRow, we should copy + // it to avoid keeping a "pointer" to a memory region which may get updated afterwards. + case StringType | _: StructType | _: ArrayType | _: MapType => + s"$row.update($ordinal, $value.copy())" case _ => s"$row.update($ordinal, $value)" } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala index f708aeff2b..dd0419d228 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala @@ -131,8 +131,6 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] case s: StructType => createCodeForStruct(ctx, input, s) case ArrayType(elementType, _) => createCodeForArray(ctx, input, elementType) case MapType(keyType, valueType, _) => createCodeForMap(ctx, input, keyType, valueType) - // UTF8String act as a pointer if it's inside UnsafeRow, so copy it to make it safe. - case StringType => ExprCode("", "false", s"$input.clone()") case udt: UserDefinedType[_] => convertToSafe(ctx, input, udt.sqlType) case _ => ExprCode("", "false", input) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala index 751b821e1b..65539a2f00 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala @@ -50,16 +50,6 @@ trait BaseGenericInternalRow extends InternalRow { override def getMap(ordinal: Int): MapData = getAs(ordinal) override def getStruct(ordinal: Int, numFields: Int): InternalRow = getAs(ordinal) - override def anyNull: Boolean = { - val len = numFields - var i = 0 - while (i < len) { - if (isNullAt(i)) { return true } - i += 1 - } - false - } - override def toString: String = { if (numFields == 0) { "[empty row]" @@ -79,6 +69,17 @@ trait BaseGenericInternalRow extends InternalRow { } } + override def copy(): GenericInternalRow = { + val len = numFields + val newValues = new Array[Any](len) + var i = 0 + while (i < len) { + newValues(i) = InternalRow.copyValue(genericGet(i)) + i += 1 + } + new GenericInternalRow(newValues) + } + override def equals(o: Any): Boolean = { if (!o.isInstanceOf[BaseGenericInternalRow]) { return false @@ -206,6 +207,4 @@ class GenericInternalRow(val values: Array[Any]) extends BaseGenericInternalRow override def setNullAt(i: Int): Unit = { values(i) = null} override def update(i: Int, value: Any): Unit = { values(i) = value } - - override def copy(): GenericInternalRow = this } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala index dd660c80a9..9e39ed9c3a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala @@ -49,7 +49,15 @@ class GenericArrayData(val array: Array[Any]) extends ArrayData { def this(seqOrArray: Any) = this(GenericArrayData.anyToSeq(seqOrArray)) - override def copy(): ArrayData = new GenericArrayData(array.clone()) + override def copy(): ArrayData = { + val newValues = new Array[Any](array.length) + var i = 0 + while (i < array.length) { + newValues(i) = InternalRow.copyValue(array(i)) + i += 1 + } + new GenericArrayData(newValues) + } override def numElements(): Int = array.length diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/RowTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/RowTest.scala index c9c9599e7f..25699de33d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/RowTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/RowTest.scala @@ -121,10 +121,6 @@ class RowTest extends FunSpec with Matchers { externalRow should be theSameInstanceAs externalRow.copy() } - it("copy should return same ref for internal rows") { - internalRow should be theSameInstanceAs internalRow.copy() - } - it("toSeq should not expose internal state for external rows") { val modifiedValues = modifyValues(externalRow.toSeq) externalRow.toSeq should not equal modifiedValues diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MapDataSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MapDataSuite.scala deleted file mode 100644 index 25a675a902..0000000000 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MapDataSuite.scala +++ /dev/null @@ -1,57 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.expressions - -import scala.collection._ - -import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.util.ArrayBasedMapData -import org.apache.spark.sql.types.{DataType, IntegerType, MapType, StringType} -import org.apache.spark.unsafe.types.UTF8String - -class MapDataSuite extends SparkFunSuite { - - test("inequality tests") { - def u(str: String): UTF8String = UTF8String.fromString(str) - - // test data - val testMap1 = Map(u("key1") -> 1) - val testMap2 = Map(u("key1") -> 1, u("key2") -> 2) - val testMap3 = Map(u("key1") -> 1) - val testMap4 = Map(u("key1") -> 1, u("key2") -> 2) - - // ArrayBasedMapData - val testArrayMap1 = ArrayBasedMapData(testMap1.toMap) - val testArrayMap2 = ArrayBasedMapData(testMap2.toMap) - val testArrayMap3 = ArrayBasedMapData(testMap3.toMap) - val testArrayMap4 = ArrayBasedMapData(testMap4.toMap) - assert(testArrayMap1 !== testArrayMap3) - assert(testArrayMap2 !== testArrayMap4) - - // UnsafeMapData - val unsafeConverter = UnsafeProjection.create(Array[DataType](MapType(StringType, IntegerType))) - val row = new GenericInternalRow(1) - def toUnsafeMap(map: ArrayBasedMapData): UnsafeMapData = { - row.update(0, map) - val unsafeRow = unsafeConverter.apply(row) - unsafeRow.getMap(0).copy - } - assert(toUnsafeMap(testArrayMap1) !== toUnsafeMap(testArrayMap3)) - assert(toUnsafeMap(testArrayMap2) !== toUnsafeMap(testArrayMap4)) - } -} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala index 58ea5b9cb5..0cd0d88591 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala @@ -172,4 +172,40 @@ class GeneratedProjectionSuite extends SparkFunSuite { assert(unsafe1 === unsafe3) assert(unsafe1.getStruct(1, 7) === unsafe3.getStruct(1, 7)) } + + test("MutableProjection should not cache content from the input row") { + val mutableProj = GenerateMutableProjection.generate( + Seq(BoundReference(0, new StructType().add("i", StringType), true))) + val row = new GenericInternalRow(1) + mutableProj.target(row) + + val unsafeProj = GenerateUnsafeProjection.generate( + Seq(BoundReference(0, new StructType().add("i", StringType), true))) + val unsafeRow = unsafeProj.apply(InternalRow(InternalRow(UTF8String.fromString("a")))) + + mutableProj.apply(unsafeRow) + assert(row.getStruct(0, 1).getString(0) == "a") + + // Even if the input row of the mutable projection has been changed, the target mutable row + // should keep same. + unsafeProj.apply(InternalRow(InternalRow(UTF8String.fromString("b")))) + assert(row.getStruct(0, 1).getString(0).toString == "a") + } + + test("SafeProjection should not cache content from the input row") { + val safeProj = GenerateSafeProjection.generate( + Seq(BoundReference(0, new StructType().add("i", StringType), true))) + + val unsafeProj = GenerateUnsafeProjection.generate( + Seq(BoundReference(0, new StructType().add("i", StringType), true))) + val unsafeRow = unsafeProj.apply(InternalRow(InternalRow(UTF8String.fromString("a")))) + + val row = safeProj.apply(unsafeRow) + assert(row.getStruct(0, 1).getString(0) == "a") + + // Even if the input row of the mutable projection has been changed, the target mutable row + // should keep same. + unsafeProj.apply(InternalRow(InternalRow(UTF8String.fromString("b")))) + assert(row.getStruct(0, 1).getString(0).toString == "a") + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ComplexDataSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ComplexDataSuite.scala new file mode 100644 index 0000000000..9d285916bc --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ComplexDataSuite.scala @@ -0,0 +1,107 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.util + +import scala.collection._ + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{BoundReference, GenericInternalRow, SpecificInternalRow, UnsafeMapData, UnsafeProjection} +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection +import org.apache.spark.sql.types.{DataType, IntegerType, MapType, StringType} +import org.apache.spark.unsafe.types.UTF8String + +class ComplexDataSuite extends SparkFunSuite { + def utf8(str: String): UTF8String = UTF8String.fromString(str) + + test("inequality tests for MapData") { + // test data + val testMap1 = Map(utf8("key1") -> 1) + val testMap2 = Map(utf8("key1") -> 1, utf8("key2") -> 2) + val testMap3 = Map(utf8("key1") -> 1) + val testMap4 = Map(utf8("key1") -> 1, utf8("key2") -> 2) + + // ArrayBasedMapData + val testArrayMap1 = ArrayBasedMapData(testMap1.toMap) + val testArrayMap2 = ArrayBasedMapData(testMap2.toMap) + val testArrayMap3 = ArrayBasedMapData(testMap3.toMap) + val testArrayMap4 = ArrayBasedMapData(testMap4.toMap) + assert(testArrayMap1 !== testArrayMap3) + assert(testArrayMap2 !== testArrayMap4) + + // UnsafeMapData + val unsafeConverter = UnsafeProjection.create(Array[DataType](MapType(StringType, IntegerType))) + val row = new GenericInternalRow(1) + def toUnsafeMap(map: ArrayBasedMapData): UnsafeMapData = { + row.update(0, map) + val unsafeRow = unsafeConverter.apply(row) + unsafeRow.getMap(0).copy + } + assert(toUnsafeMap(testArrayMap1) !== toUnsafeMap(testArrayMap3)) + assert(toUnsafeMap(testArrayMap2) !== toUnsafeMap(testArrayMap4)) + } + + test("GenericInternalRow.copy return a new instance that is independent from the old one") { + val project = GenerateUnsafeProjection.generate(Seq(BoundReference(0, StringType, true))) + val unsafeRow = project.apply(InternalRow(utf8("a"))) + + val genericRow = new GenericInternalRow(Array[Any](unsafeRow.getUTF8String(0))) + val copiedGenericRow = genericRow.copy() + assert(copiedGenericRow.getString(0) == "a") + project.apply(InternalRow(UTF8String.fromString("b"))) + // The copied internal row should not be changed externally. + assert(copiedGenericRow.getString(0) == "a") + } + + test("SpecificMutableRow.copy return a new instance that is independent from the old one") { + val project = GenerateUnsafeProjection.generate(Seq(BoundReference(0, StringType, true))) + val unsafeRow = project.apply(InternalRow(utf8("a"))) + + val mutableRow = new SpecificInternalRow(Seq(StringType)) + mutableRow(0) = unsafeRow.getUTF8String(0) + val copiedMutableRow = mutableRow.copy() + assert(copiedMutableRow.getString(0) == "a") + project.apply(InternalRow(UTF8String.fromString("b"))) + // The copied internal row should not be changed externally. + assert(copiedMutableRow.getString(0) == "a") + } + + test("GenericArrayData.copy return a new instance that is independent from the old one") { + val project = GenerateUnsafeProjection.generate(Seq(BoundReference(0, StringType, true))) + val unsafeRow = project.apply(InternalRow(utf8("a"))) + + val genericArray = new GenericArrayData(Array[Any](unsafeRow.getUTF8String(0))) + val copiedGenericArray = genericArray.copy() + assert(copiedGenericArray.getUTF8String(0).toString == "a") + project.apply(InternalRow(UTF8String.fromString("b"))) + // The copied array data should not be changed externally. + assert(copiedGenericArray.getUTF8String(0).toString == "a") + } + + test("copy on nested complex type") { + val project = GenerateUnsafeProjection.generate(Seq(BoundReference(0, StringType, true))) + val unsafeRow = project.apply(InternalRow(utf8("a"))) + + val arrayOfRow = new GenericArrayData(Array[Any](InternalRow(unsafeRow.getUTF8String(0)))) + val copied = arrayOfRow.copy() + assert(copied.getStruct(0, 1).getUTF8String(0).toString == "a") + project.apply(InternalRow(UTF8String.fromString("b"))) + // The copied data should not be changed externally. + assert(copied.getStruct(0, 1).getUTF8String(0).toString == "a") + } +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java index e23a64350c..34dc3af9b8 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java @@ -149,7 +149,7 @@ public final class ColumnarBatch { } else if (dt instanceof DoubleType) { row.setDouble(i, getDouble(i)); } else if (dt instanceof StringType) { - row.update(i, getUTF8String(i)); + row.update(i, getUTF8String(i).copy()); } else if (dt instanceof BinaryType) { row.update(i, getBinary(i)); } else if (dt instanceof DecimalType) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala index bea2dce1a7..a5a444b160 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala @@ -86,17 +86,6 @@ class SortBasedAggregationIterator( // The aggregation buffer used by the sort-based aggregation. private[this] val sortBasedAggregationBuffer: InternalRow = newBuffer - // This safe projection is used to turn the input row into safe row. This is necessary - // because the input row may be produced by unsafe projection in child operator and all the - // produced rows share one byte array. However, when we update the aggregate buffer according to - // the input row, we may cache some values from input row, e.g. `Max` will keep the max value from - // input row via MutableProjection, `CollectList` will keep all values in an array via - // ImperativeAggregate framework. These values may get changed unexpectedly if the underlying - // unsafe projection update the shared byte array. By applying a safe projection to the input row, - // we can cut down the connection from input row to the shared byte array, and thus it's safe to - // cache values from input row while updating the aggregation buffer. - private[this] val safeProj: Projection = FromUnsafeProjection(valueAttributes.map(_.dataType)) - protected def initialize(): Unit = { if (inputIterator.hasNext) { initializeBuffer(sortBasedAggregationBuffer) @@ -119,7 +108,7 @@ class SortBasedAggregationIterator( // We create a variable to track if we see the next group. var findNextPartition = false // firstRowInNextGroup is the first row of this group. We first process it. - processRow(sortBasedAggregationBuffer, safeProj(firstRowInNextGroup)) + processRow(sortBasedAggregationBuffer, firstRowInNextGroup) // The search will stop when we see the next group or there is no // input row left in the iter. @@ -130,7 +119,7 @@ class SortBasedAggregationIterator( // Check if the current row belongs the current input row. if (currentGroupingKey == groupingKey) { - processRow(sortBasedAggregationBuffer, safeProj(currentRow)) + processRow(sortBasedAggregationBuffer, currentRow) } else { // We find a new group. findNextPartition = true diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala index d3fa0dcd2d..fc977f2fd5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala @@ -56,7 +56,6 @@ class MutableUnsafeRow(val writer: UnsafeRowWriter) extends BaseGenericInternalR // all other methods inherited from GenericMutableRow are not need override protected def genericGet(ordinal: Int): Any = throw new UnsupportedOperationException override def numFields: Int = throw new UnsupportedOperationException - override def copy(): InternalRow = throw new UnsupportedOperationException } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/AggregateProcessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/AggregateProcessor.scala index 2195c6ea95..bc141b36e6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/AggregateProcessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/AggregateProcessor.scala @@ -145,13 +145,10 @@ private[window] final class AggregateProcessor( /** Update the buffer. */ def update(input: InternalRow): Unit = { - // TODO(hvanhovell) this sacrifices performance for correctness. We should make sure that - // MutableProjection makes copies of the complex input objects it buffer. - val copy = input.copy() - updateProjection(join(buffer, copy)) + updateProjection(join(buffer, input)) var i = 0 while (i < numImperatives) { - imperatives(i).update(buffer, copy) + imperatives(i).update(buffer, input) i += 1 } } -- GitLab