diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index 40b9fc9534f445545fbeb5e4949d9f6269fe7c60..9de4ca71ff6d478feb22d8c8426973a34a814054 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -1088,6 +1088,12 @@ public final class UTF8String implements Comparable<UTF8String>, Externalizable, return fromBytes(getBytes()); } + public UTF8String copy() { + byte[] bytes = new byte[numBytes]; + copyMemory(base, offset, bytes, BYTE_ARRAY_OFFSET, numBytes); + return fromBytes(bytes); + } + @Override public int compareTo(@Nonnull final UTF8String other) { int len = Math.min(numBytes, other.numBytes); diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala index 256f64e320be89f640543a5b50b51c68089d4f03..29110640d64f2a990ca9a7188097b856be427867 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala @@ -18,7 +18,9 @@ package org.apache.spark.sql.catalyst import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} import org.apache.spark.sql.types.{DataType, Decimal, StructType} +import org.apache.spark.unsafe.types.UTF8String /** * An abstract class for row used internally in Spark SQL, which only contains the columns as @@ -33,6 +35,10 @@ abstract class InternalRow extends SpecializedGetters with Serializable { def setNullAt(i: Int): Unit + /** + * Updates the value at column `i`. Note that after updating, the given value will be kept in this + * row, and the caller side should guarantee that this value won't be changed afterwards. + */ def update(i: Int, value: Any): Unit // default implementation (slow) @@ -58,7 +64,15 @@ abstract class InternalRow extends SpecializedGetters with Serializable { def copy(): InternalRow /** Returns true if there are any NULL values in this row. */ - def anyNull: Boolean + def anyNull: Boolean = { + val len = numFields + var i = 0 + while (i < len) { + if (isNullAt(i)) { return true } + i += 1 + } + false + } /* ---------------------- utility methods for Scala ---------------------- */ @@ -94,4 +108,15 @@ object InternalRow { /** Returns an empty [[InternalRow]]. */ val empty = apply() + + /** + * Copies the given value if it's string/struct/array/map type. + */ + def copyValue(value: Any): Any = value match { + case v: UTF8String => v.copy() + case v: InternalRow => v.copy() + case v: ArrayData => v.copy() + case v: MapData => v.copy() + case _ => value + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 43df19ba009a831a38347c939c0e0b89162e7707..3862e64b9d828c40914401d731189101ffe37e63 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -1047,7 +1047,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String final $rowClass $result = new $rowClass(${fieldsCasts.length}); final InternalRow $tmpRow = $c; $fieldsEvalCode - $evPrim = $result.copy(); + $evPrim = $result; """ } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificInternalRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificInternalRow.scala index 74e0b4691d4cc96281afc8628e78c72c8af22acb..75feaf670c84a53691ffdf582710a73c2a946e63 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificInternalRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificInternalRow.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.types._ /** @@ -220,17 +219,6 @@ final class SpecificInternalRow(val values: Array[MutableValue]) extends BaseGen override def isNullAt(i: Int): Boolean = values(i).isNull - override def copy(): InternalRow = { - val newValues = new Array[Any](values.length) - var i = 0 - while (i < values.length) { - newValues(i) = values(i).boxed - i += 1 - } - - new GenericInternalRow(newValues) - } - override protected def genericGet(i: Int): Any = values(i).boxed override def update(ordinal: Int, value: Any) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala index 26cd9ab6653836dbb9d33deaa85644461b9445cb..0d2f9889a27d581121c63dd989894c36b6a28753 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala @@ -52,7 +52,7 @@ abstract class Collect[T <: Growable[Any] with Iterable[Any]] extends TypedImper // Do not allow null values. We follow the semantics of Hive's collect_list/collect_set here. // See: org.apache.hadoop.hive.ql.udf.generic.GenericUDAFMkCollectionEvaluator if (value != null) { - buffer += value + buffer += InternalRow.copyValue(value) } buffer } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala index fffcc7c9ef53a2bbe3a984679450b11f01e9f259..7af49014358570c4a3b26f00519bb532fcb5b13a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala @@ -317,6 +317,9 @@ abstract class ImperativeAggregate extends AggregateFunction with CodegenFallbac * Updates its aggregation buffer, located in `mutableAggBuffer`, based on the given `inputRow`. * * Use `fieldNumber + mutableAggBufferOffset` to access fields of `mutableAggBuffer`. + * + * Note that, the input row may be produced by unsafe projection and it may not be safe to cache + * some fields of the input row, as the values can be changed unexpectedly. */ def update(mutableAggBuffer: InternalRow, inputRow: InternalRow): Unit @@ -326,6 +329,9 @@ abstract class ImperativeAggregate extends AggregateFunction with CodegenFallbac * * Use `fieldNumber + mutableAggBufferOffset` to access fields of `mutableAggBuffer`. * Use `fieldNumber + inputAggBufferOffset` to access fields of `inputAggBuffer`. + * + * Note that, the input row may be produced by unsafe projection and it may not be safe to cache + * some fields of the input row, as the values can be changed unexpectedly. */ def merge(mutableAggBuffer: InternalRow, inputAggBuffer: InternalRow): Unit } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 5158949b9562993c9da4db6b003d3256e09bbb5d..b15bf2ca7c11627781b4eaf23df009f6bf1e6249 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -408,9 +408,11 @@ class CodegenContext { dataType match { case _ if isPrimitiveType(jt) => s"$row.set${primitiveTypeName(jt)}($ordinal, $value)" case t: DecimalType => s"$row.setDecimal($ordinal, $value, ${t.precision})" - // The UTF8String may came from UnsafeRow, otherwise clone is cheap (re-use the bytes) - case StringType => s"$row.update($ordinal, $value.clone())" case udt: UserDefinedType[_] => setColumn(row, udt.sqlType, ordinal, value) + // The UTF8String, InternalRow, ArrayData and MapData may came from UnsafeRow, we should copy + // it to avoid keeping a "pointer" to a memory region which may get updated afterwards. + case StringType | _: StructType | _: ArrayType | _: MapType => + s"$row.update($ordinal, $value.copy())" case _ => s"$row.update($ordinal, $value)" } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala index f708aeff2b146685ea1280d65a120e4fee02fc3a..dd0419d2286d162554d220cab8ef6d3ccc36a071 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala @@ -131,8 +131,6 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] case s: StructType => createCodeForStruct(ctx, input, s) case ArrayType(elementType, _) => createCodeForArray(ctx, input, elementType) case MapType(keyType, valueType, _) => createCodeForMap(ctx, input, keyType, valueType) - // UTF8String act as a pointer if it's inside UnsafeRow, so copy it to make it safe. - case StringType => ExprCode("", "false", s"$input.clone()") case udt: UserDefinedType[_] => convertToSafe(ctx, input, udt.sqlType) case _ => ExprCode("", "false", input) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala index 751b821e1b009aecf018824a154b0c219dd98fad..65539a2f00e6c332cf3c468cef6faa5a57d9818e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala @@ -50,16 +50,6 @@ trait BaseGenericInternalRow extends InternalRow { override def getMap(ordinal: Int): MapData = getAs(ordinal) override def getStruct(ordinal: Int, numFields: Int): InternalRow = getAs(ordinal) - override def anyNull: Boolean = { - val len = numFields - var i = 0 - while (i < len) { - if (isNullAt(i)) { return true } - i += 1 - } - false - } - override def toString: String = { if (numFields == 0) { "[empty row]" @@ -79,6 +69,17 @@ trait BaseGenericInternalRow extends InternalRow { } } + override def copy(): GenericInternalRow = { + val len = numFields + val newValues = new Array[Any](len) + var i = 0 + while (i < len) { + newValues(i) = InternalRow.copyValue(genericGet(i)) + i += 1 + } + new GenericInternalRow(newValues) + } + override def equals(o: Any): Boolean = { if (!o.isInstanceOf[BaseGenericInternalRow]) { return false @@ -206,6 +207,4 @@ class GenericInternalRow(val values: Array[Any]) extends BaseGenericInternalRow override def setNullAt(i: Int): Unit = { values(i) = null} override def update(i: Int, value: Any): Unit = { values(i) = value } - - override def copy(): GenericInternalRow = this } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala index dd660c80a9c3c1ad0b337e322fef44b2139cd383..9e39ed9c3a778e9bf69b64413356af9a71d83fcb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala @@ -49,7 +49,15 @@ class GenericArrayData(val array: Array[Any]) extends ArrayData { def this(seqOrArray: Any) = this(GenericArrayData.anyToSeq(seqOrArray)) - override def copy(): ArrayData = new GenericArrayData(array.clone()) + override def copy(): ArrayData = { + val newValues = new Array[Any](array.length) + var i = 0 + while (i < array.length) { + newValues(i) = InternalRow.copyValue(array(i)) + i += 1 + } + new GenericArrayData(newValues) + } override def numElements(): Int = array.length diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/RowTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/RowTest.scala index c9c9599e7f4633c575b0213c507d9cb3f2fd7cc9..25699de33d717364a113cec8a57588172fe096bd 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/RowTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/RowTest.scala @@ -121,10 +121,6 @@ class RowTest extends FunSpec with Matchers { externalRow should be theSameInstanceAs externalRow.copy() } - it("copy should return same ref for internal rows") { - internalRow should be theSameInstanceAs internalRow.copy() - } - it("toSeq should not expose internal state for external rows") { val modifiedValues = modifyValues(externalRow.toSeq) externalRow.toSeq should not equal modifiedValues diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MapDataSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MapDataSuite.scala deleted file mode 100644 index 25a675a90276dbd3d373eda3b87f18d61a49e7cd..0000000000000000000000000000000000000000 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MapDataSuite.scala +++ /dev/null @@ -1,57 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.expressions - -import scala.collection._ - -import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.util.ArrayBasedMapData -import org.apache.spark.sql.types.{DataType, IntegerType, MapType, StringType} -import org.apache.spark.unsafe.types.UTF8String - -class MapDataSuite extends SparkFunSuite { - - test("inequality tests") { - def u(str: String): UTF8String = UTF8String.fromString(str) - - // test data - val testMap1 = Map(u("key1") -> 1) - val testMap2 = Map(u("key1") -> 1, u("key2") -> 2) - val testMap3 = Map(u("key1") -> 1) - val testMap4 = Map(u("key1") -> 1, u("key2") -> 2) - - // ArrayBasedMapData - val testArrayMap1 = ArrayBasedMapData(testMap1.toMap) - val testArrayMap2 = ArrayBasedMapData(testMap2.toMap) - val testArrayMap3 = ArrayBasedMapData(testMap3.toMap) - val testArrayMap4 = ArrayBasedMapData(testMap4.toMap) - assert(testArrayMap1 !== testArrayMap3) - assert(testArrayMap2 !== testArrayMap4) - - // UnsafeMapData - val unsafeConverter = UnsafeProjection.create(Array[DataType](MapType(StringType, IntegerType))) - val row = new GenericInternalRow(1) - def toUnsafeMap(map: ArrayBasedMapData): UnsafeMapData = { - row.update(0, map) - val unsafeRow = unsafeConverter.apply(row) - unsafeRow.getMap(0).copy - } - assert(toUnsafeMap(testArrayMap1) !== toUnsafeMap(testArrayMap3)) - assert(toUnsafeMap(testArrayMap2) !== toUnsafeMap(testArrayMap4)) - } -} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala index 58ea5b9cb52d3ab227cb0457b352c0214683d3f2..0cd0d8859145f2ba94d8c8af6d33dd351934a3dc 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala @@ -172,4 +172,40 @@ class GeneratedProjectionSuite extends SparkFunSuite { assert(unsafe1 === unsafe3) assert(unsafe1.getStruct(1, 7) === unsafe3.getStruct(1, 7)) } + + test("MutableProjection should not cache content from the input row") { + val mutableProj = GenerateMutableProjection.generate( + Seq(BoundReference(0, new StructType().add("i", StringType), true))) + val row = new GenericInternalRow(1) + mutableProj.target(row) + + val unsafeProj = GenerateUnsafeProjection.generate( + Seq(BoundReference(0, new StructType().add("i", StringType), true))) + val unsafeRow = unsafeProj.apply(InternalRow(InternalRow(UTF8String.fromString("a")))) + + mutableProj.apply(unsafeRow) + assert(row.getStruct(0, 1).getString(0) == "a") + + // Even if the input row of the mutable projection has been changed, the target mutable row + // should keep same. + unsafeProj.apply(InternalRow(InternalRow(UTF8String.fromString("b")))) + assert(row.getStruct(0, 1).getString(0).toString == "a") + } + + test("SafeProjection should not cache content from the input row") { + val safeProj = GenerateSafeProjection.generate( + Seq(BoundReference(0, new StructType().add("i", StringType), true))) + + val unsafeProj = GenerateUnsafeProjection.generate( + Seq(BoundReference(0, new StructType().add("i", StringType), true))) + val unsafeRow = unsafeProj.apply(InternalRow(InternalRow(UTF8String.fromString("a")))) + + val row = safeProj.apply(unsafeRow) + assert(row.getStruct(0, 1).getString(0) == "a") + + // Even if the input row of the mutable projection has been changed, the target mutable row + // should keep same. + unsafeProj.apply(InternalRow(InternalRow(UTF8String.fromString("b")))) + assert(row.getStruct(0, 1).getString(0).toString == "a") + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ComplexDataSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ComplexDataSuite.scala new file mode 100644 index 0000000000000000000000000000000000000000..9d285916bcf42cacb77a9a98b182a2991345961c --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ComplexDataSuite.scala @@ -0,0 +1,107 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.util + +import scala.collection._ + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{BoundReference, GenericInternalRow, SpecificInternalRow, UnsafeMapData, UnsafeProjection} +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection +import org.apache.spark.sql.types.{DataType, IntegerType, MapType, StringType} +import org.apache.spark.unsafe.types.UTF8String + +class ComplexDataSuite extends SparkFunSuite { + def utf8(str: String): UTF8String = UTF8String.fromString(str) + + test("inequality tests for MapData") { + // test data + val testMap1 = Map(utf8("key1") -> 1) + val testMap2 = Map(utf8("key1") -> 1, utf8("key2") -> 2) + val testMap3 = Map(utf8("key1") -> 1) + val testMap4 = Map(utf8("key1") -> 1, utf8("key2") -> 2) + + // ArrayBasedMapData + val testArrayMap1 = ArrayBasedMapData(testMap1.toMap) + val testArrayMap2 = ArrayBasedMapData(testMap2.toMap) + val testArrayMap3 = ArrayBasedMapData(testMap3.toMap) + val testArrayMap4 = ArrayBasedMapData(testMap4.toMap) + assert(testArrayMap1 !== testArrayMap3) + assert(testArrayMap2 !== testArrayMap4) + + // UnsafeMapData + val unsafeConverter = UnsafeProjection.create(Array[DataType](MapType(StringType, IntegerType))) + val row = new GenericInternalRow(1) + def toUnsafeMap(map: ArrayBasedMapData): UnsafeMapData = { + row.update(0, map) + val unsafeRow = unsafeConverter.apply(row) + unsafeRow.getMap(0).copy + } + assert(toUnsafeMap(testArrayMap1) !== toUnsafeMap(testArrayMap3)) + assert(toUnsafeMap(testArrayMap2) !== toUnsafeMap(testArrayMap4)) + } + + test("GenericInternalRow.copy return a new instance that is independent from the old one") { + val project = GenerateUnsafeProjection.generate(Seq(BoundReference(0, StringType, true))) + val unsafeRow = project.apply(InternalRow(utf8("a"))) + + val genericRow = new GenericInternalRow(Array[Any](unsafeRow.getUTF8String(0))) + val copiedGenericRow = genericRow.copy() + assert(copiedGenericRow.getString(0) == "a") + project.apply(InternalRow(UTF8String.fromString("b"))) + // The copied internal row should not be changed externally. + assert(copiedGenericRow.getString(0) == "a") + } + + test("SpecificMutableRow.copy return a new instance that is independent from the old one") { + val project = GenerateUnsafeProjection.generate(Seq(BoundReference(0, StringType, true))) + val unsafeRow = project.apply(InternalRow(utf8("a"))) + + val mutableRow = new SpecificInternalRow(Seq(StringType)) + mutableRow(0) = unsafeRow.getUTF8String(0) + val copiedMutableRow = mutableRow.copy() + assert(copiedMutableRow.getString(0) == "a") + project.apply(InternalRow(UTF8String.fromString("b"))) + // The copied internal row should not be changed externally. + assert(copiedMutableRow.getString(0) == "a") + } + + test("GenericArrayData.copy return a new instance that is independent from the old one") { + val project = GenerateUnsafeProjection.generate(Seq(BoundReference(0, StringType, true))) + val unsafeRow = project.apply(InternalRow(utf8("a"))) + + val genericArray = new GenericArrayData(Array[Any](unsafeRow.getUTF8String(0))) + val copiedGenericArray = genericArray.copy() + assert(copiedGenericArray.getUTF8String(0).toString == "a") + project.apply(InternalRow(UTF8String.fromString("b"))) + // The copied array data should not be changed externally. + assert(copiedGenericArray.getUTF8String(0).toString == "a") + } + + test("copy on nested complex type") { + val project = GenerateUnsafeProjection.generate(Seq(BoundReference(0, StringType, true))) + val unsafeRow = project.apply(InternalRow(utf8("a"))) + + val arrayOfRow = new GenericArrayData(Array[Any](InternalRow(unsafeRow.getUTF8String(0)))) + val copied = arrayOfRow.copy() + assert(copied.getStruct(0, 1).getUTF8String(0).toString == "a") + project.apply(InternalRow(UTF8String.fromString("b"))) + // The copied data should not be changed externally. + assert(copied.getStruct(0, 1).getUTF8String(0).toString == "a") + } +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java index e23a64350cbc52896f01cd101f7bc84e34ab1fe6..34dc3af9b85c85ea78369a74e1e0cd28fe211120 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java @@ -149,7 +149,7 @@ public final class ColumnarBatch { } else if (dt instanceof DoubleType) { row.setDouble(i, getDouble(i)); } else if (dt instanceof StringType) { - row.update(i, getUTF8String(i)); + row.update(i, getUTF8String(i).copy()); } else if (dt instanceof BinaryType) { row.update(i, getBinary(i)); } else if (dt instanceof DecimalType) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala index bea2dce1a7657c6367fbf1fd4c0685829e7205f7..a5a444b160c6325f961fd3a63bb511660594458a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala @@ -86,17 +86,6 @@ class SortBasedAggregationIterator( // The aggregation buffer used by the sort-based aggregation. private[this] val sortBasedAggregationBuffer: InternalRow = newBuffer - // This safe projection is used to turn the input row into safe row. This is necessary - // because the input row may be produced by unsafe projection in child operator and all the - // produced rows share one byte array. However, when we update the aggregate buffer according to - // the input row, we may cache some values from input row, e.g. `Max` will keep the max value from - // input row via MutableProjection, `CollectList` will keep all values in an array via - // ImperativeAggregate framework. These values may get changed unexpectedly if the underlying - // unsafe projection update the shared byte array. By applying a safe projection to the input row, - // we can cut down the connection from input row to the shared byte array, and thus it's safe to - // cache values from input row while updating the aggregation buffer. - private[this] val safeProj: Projection = FromUnsafeProjection(valueAttributes.map(_.dataType)) - protected def initialize(): Unit = { if (inputIterator.hasNext) { initializeBuffer(sortBasedAggregationBuffer) @@ -119,7 +108,7 @@ class SortBasedAggregationIterator( // We create a variable to track if we see the next group. var findNextPartition = false // firstRowInNextGroup is the first row of this group. We first process it. - processRow(sortBasedAggregationBuffer, safeProj(firstRowInNextGroup)) + processRow(sortBasedAggregationBuffer, firstRowInNextGroup) // The search will stop when we see the next group or there is no // input row left in the iter. @@ -130,7 +119,7 @@ class SortBasedAggregationIterator( // Check if the current row belongs the current input row. if (currentGroupingKey == groupingKey) { - processRow(sortBasedAggregationBuffer, safeProj(currentRow)) + processRow(sortBasedAggregationBuffer, currentRow) } else { // We find a new group. findNextPartition = true diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala index d3fa0dcd2d7c31a5a872a20773a0829b3c687fb1..fc977f2fd553089c6b4e000912b3fdbd50c3aabe 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala @@ -56,7 +56,6 @@ class MutableUnsafeRow(val writer: UnsafeRowWriter) extends BaseGenericInternalR // all other methods inherited from GenericMutableRow are not need override protected def genericGet(ordinal: Int): Any = throw new UnsupportedOperationException override def numFields: Int = throw new UnsupportedOperationException - override def copy(): InternalRow = throw new UnsupportedOperationException } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/AggregateProcessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/AggregateProcessor.scala index 2195c6ea95948350d8cd754c9aefa29ca004d0d5..bc141b36e63b419399b3b35e449e074f0127967a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/AggregateProcessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/AggregateProcessor.scala @@ -145,13 +145,10 @@ private[window] final class AggregateProcessor( /** Update the buffer. */ def update(input: InternalRow): Unit = { - // TODO(hvanhovell) this sacrifices performance for correctness. We should make sure that - // MutableProjection makes copies of the complex input objects it buffer. - val copy = input.copy() - updateProjection(join(buffer, copy)) + updateProjection(join(buffer, input)) var i = 0 while (i < numImperatives) { - imperatives(i).update(buffer, copy) + imperatives(i).update(buffer, input) i += 1 } }