Commit ce7ddabb authored by Yin Huai, committed by Michael Armbrust

[SPARK-6368][SQL] Build a specialized serializer for Exchange operator.

JIRA: https://issues.apache.org/jira/browse/SPARK-6368

Author: Yin Huai <yhuai@databricks.com>

Closes #5497 from yhuai/serializer2 and squashes the following commits:

da562c5 [Yin Huai] Merge remote-tracking branch 'upstream/master' into serializer2
50e0c3d [Yin Huai] When no field is emitted to shuffle, use SparkSqlSerializer for now.
9f1ed92 [Yin Huai] Merge remote-tracking branch 'upstream/master' into serializer2
6d07678 [Yin Huai] Address comments.
4273b8c [Yin Huai] Enabled SparkSqlSerializer2.
09e587a [Yin Huai] Remove TODO.
791b96a [Yin Huai] Use UTF8String.
60a1487 [Yin Huai] Merge remote-tracking branch 'upstream/master' into serializer2
3e09655 [Yin Huai] Use getAs for Date column.
43b9fb4 [Yin Huai] Test.
8297732 [Yin Huai] Fix test.
c9373c8 [Yin Huai] Support DecimalType.
2379eeb [Yin Huai] ASF header.
39704ab [Yin Huai] Specialized serializer for Exchange.
parent 517bdf36
......@@ -64,6 +64,8 @@ private[spark] object SQLConf {
// Set to false when debugging requires the ability to look at invalid query plans.
val DATAFRAME_EAGER_ANALYSIS = "spark.sql.eagerAnalysis"
val USE_SQL_SERIALIZER2 = "spark.sql.useSerializer2"
object Deprecated {
val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks"
}
......@@ -147,6 +149,8 @@ private[sql] class SQLConf extends Serializable {
*/
private[spark] def codegenEnabled: Boolean = getConf(CODEGEN_ENABLED, "false").toBoolean
private[spark] def useSqlSerializer2: Boolean = getConf(USE_SQL_SERIALIZER2, "true").toBoolean
/**
* Upper bound on the sizes (in bytes) of the tables qualified for the auto conversion to
* a broadcast value during the physical executions of join operations. Setting this to -1
......
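As a brief usage note (not part of the patch): the new code path is gated by the spark.sql.useSerializer2 flag added above, which defaults to true. A minimal sketch of toggling it at runtime, mirroring how the test suite below sets it via a SET command:
// Illustrative only: disable or re-enable SparkSqlSerializer2 for Exchange through SQLConf.
// `sqlContext` is assumed to be an existing SQLContext.
sqlContext.sql("SET spark.sql.useSerializer2=false")  // fall back to SparkSqlSerializer
sqlContext.sql("SET spark.sql.useSerializer2=true")   // the default: use SparkSqlSerializer2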
......@@ -19,13 +19,15 @@ package org.apache.spark.sql.execution
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.shuffle.sort.SortShuffleManager
-import org.apache.spark.{SparkEnv, HashPartitioner, RangePartitioner, SparkConf}
import org.apache.spark.{SparkEnv, HashPartitioner, RangePartitioner}
import org.apache.spark.rdd.{RDD, ShuffledRDD}
import org.apache.spark.serializer.Serializer
import org.apache.spark.sql.{SQLContext, Row}
import org.apache.spark.sql.catalyst.errors.attachTree
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.types.DataType
import org.apache.spark.util.MutablePair
object Exchange {
......@@ -77,9 +79,48 @@ case class Exchange(
}
}
-override def execute(): RDD[Row] = attachTree(this , "execute") {
-lazy val sparkConf = child.sqlContext.sparkContext.getConf
@transient private lazy val sparkConf = child.sqlContext.sparkContext.getConf
def serializer(
keySchema: Array[DataType],
valueSchema: Array[DataType],
numPartitions: Int): Serializer = {
// In ExternalSorter's spillToMergeableFile function, key-value pairs are written out
// through write(key) and then write(value) instead of write((key, value)). Because
// SparkSqlSerializer2 assumes that objects passed in are Product2, we cannot safely use
// it when ExternalSorter's spillToMergeableFile will be used.
// So, we will not use SparkSqlSerializer2 when:
// - sort-based shuffle is enabled and the number of reducers (numPartitions) is greater
// than the bypassMergeThreshold; or
// - newOrdering is defined.
val cannotUseSqlSerializer2 =
(sortBasedShuffleOn && numPartitions > bypassMergeThreshold) || newOrdering.nonEmpty
// It is true when there is no field that needs to be written out.
// For now, we will not use SparkSqlSerializer2 when noField is true.
val noField =
(keySchema == null || keySchema.length == 0) &&
(valueSchema == null || valueSchema.length == 0)
val useSqlSerializer2 =
child.sqlContext.conf.useSqlSerializer2 && // SparkSqlSerializer2 is enabled.
!cannotUseSqlSerializer2 && // Safe to use Serializer2.
SparkSqlSerializer2.support(keySchema) && // The schema of key is supported.
SparkSqlSerializer2.support(valueSchema) && // The schema of value is supported.
!noField
val serializer = if (useSqlSerializer2) {
logInfo("Using SparkSqlSerializer2.")
new SparkSqlSerializer2(keySchema, valueSchema)
} else {
logInfo("Using SparkSqlSerializer.")
new SparkSqlSerializer(sparkConf)
}
serializer
}
override def execute(): RDD[Row] = attachTree(this , "execute") {
newPartitioning match {
case HashPartitioning(expressions, numPartitions) =>
// TODO: Eliminate redundant expressions in grouping key and value.
......@@ -111,7 +152,10 @@ case class Exchange(
} else {
new ShuffledRDD[Row, Row, Row](rdd, part)
}
-shuffled.setSerializer(new SparkSqlSerializer(sparkConf))
val keySchema = expressions.map(_.dataType).toArray
val valueSchema = child.output.map(_.dataType).toArray
shuffled.setSerializer(serializer(keySchema, valueSchema, numPartitions))
shuffled.map(_._2)
case RangePartitioning(sortingExpressions, numPartitions) =>
......@@ -134,7 +178,9 @@ case class Exchange(
} else {
new ShuffledRDD[Row, Null, Null](rdd, part)
}
-shuffled.setSerializer(new SparkSqlSerializer(sparkConf))
val keySchema = child.output.map(_.dataType).toArray
shuffled.setSerializer(serializer(keySchema, null, numPartitions))
shuffled.map(_._1)
case SinglePartition =>
......@@ -152,7 +198,8 @@ case class Exchange(
}
val partitioner = new HashPartitioner(1)
val shuffled = new ShuffledRDD[Null, Row, Row](rdd, partitioner)
-shuffled.setSerializer(new SparkSqlSerializer(sparkConf))
val valueSchema = child.output.map(_.dataType).toArray
shuffled.setSerializer(serializer(null, valueSchema, 1))
shuffled.map(_._2)
case _ => sys.error(s"Exchange not implemented for $newPartitioning")
......
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.execution
import java.io._
import java.math.{BigDecimal, BigInteger}
import java.nio.ByteBuffer
import java.sql.Timestamp
import scala.reflect.ClassTag
import org.apache.spark.serializer._
import org.apache.spark.Logging
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions.SpecificMutableRow
import org.apache.spark.sql.types._
/**
* The serialization stream for [[SparkSqlSerializer2]]. It assumes that the objects passed into
* its `writeObject` are [[Product2]]s. The serialization functions for the key and the value of
* the [[Product2]] are constructed based on their schemata.
* The benefit of this serialization stream is that, compared with general-purpose serializers like
* Kryo and the Java serializer, it can significantly reduce the size of the serialized data and
* has a lower allocation cost, which can benefit the shuffle operation. Right now, its main
* limitations are:
* 1. It does not support complex types, i.e. Map, Array, and Struct.
* 2. It assumes that the objects passed in are [[Product2]]s. So, it cannot be used when
*    [[org.apache.spark.util.collection.ExternalSorter]]'s merge sort operation is used, because
*    the objects passed to the serializer are not of type [[Product2]]. Also see the comment of
*    the `serializer` method in [[Exchange]] for more information.
*/
private[sql] class Serializer2SerializationStream(
keySchema: Array[DataType],
valueSchema: Array[DataType],
out: OutputStream)
extends SerializationStream with Logging {
val rowOut = new DataOutputStream(out)
val writeKey = SparkSqlSerializer2.createSerializationFunction(keySchema, rowOut)
val writeValue = SparkSqlSerializer2.createSerializationFunction(valueSchema, rowOut)
def writeObject[T: ClassTag](t: T): SerializationStream = {
val kv = t.asInstanceOf[Product2[Row, Row]]
writeKey(kv._1)
writeValue(kv._2)
this
}
def flush(): Unit = {
rowOut.flush()
}
def close(): Unit = {
rowOut.close()
}
}
/**
* The corresponding deserialization stream for [[Serializer2SerializationStream]].
*/
private[sql] class Serializer2DeserializationStream(
keySchema: Array[DataType],
valueSchema: Array[DataType],
in: InputStream)
extends DeserializationStream with Logging {
val rowIn = new DataInputStream(new BufferedInputStream(in))
val key = if (keySchema != null) new SpecificMutableRow(keySchema) else null
val value = if (valueSchema != null) new SpecificMutableRow(valueSchema) else null
val readKey = SparkSqlSerializer2.createDeserializationFunction(keySchema, rowIn, key)
val readValue = SparkSqlSerializer2.createDeserializationFunction(valueSchema, rowIn, value)
def readObject[T: ClassTag](): T = {
readKey()
readValue()
(key, value).asInstanceOf[T]
}
def close(): Unit = {
rowIn.close()
}
}
private[sql] class ShuffleSerializerInstance(
keySchema: Array[DataType],
valueSchema: Array[DataType])
extends SerializerInstance {
def serialize[T: ClassTag](t: T): ByteBuffer =
throw new UnsupportedOperationException("Not supported.")
def deserialize[T: ClassTag](bytes: ByteBuffer): T =
throw new UnsupportedOperationException("Not supported.")
def deserialize[T: ClassTag](bytes: ByteBuffer, loader: ClassLoader): T =
throw new UnsupportedOperationException("Not supported.")
def serializeStream(s: OutputStream): SerializationStream = {
new Serializer2SerializationStream(keySchema, valueSchema, s)
}
def deserializeStream(s: InputStream): DeserializationStream = {
new Serializer2DeserializationStream(keySchema, valueSchema, s)
}
}
/**
* SparkSqlSerializer2 is a special serializer that creates serialization and deserialization
* functions based on the schema of the data. It assumes that the values passed in are key/value
* pairs and the values returned from it are also key/value pairs.
* The schema of keys is represented by `keySchema` and that of values is represented by
* `valueSchema`.
*/
private[sql] class SparkSqlSerializer2(keySchema: Array[DataType], valueSchema: Array[DataType])
extends Serializer
with Logging
with Serializable {
def newInstance(): SerializerInstance = new ShuffleSerializerInstance(keySchema, valueSchema)
}
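To make the contract above concrete, here is a minimal, illustrative round trip through the serializer. It is not part of the patch; the schemas, rows, and byte-array streams below are assumptions chosen for the sketch.
// Illustrative sketch only: serialize and deserialize one (key, value) pair of Rows.
import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
import org.apache.spark.sql.Row
import org.apache.spark.sql.types._

val keySchema: Array[DataType] = Array(IntegerType)
val valueSchema: Array[DataType] = Array(LongType, DoubleType)
val instance = new SparkSqlSerializer2(keySchema, valueSchema).newInstance()

val bytesOut = new ByteArrayOutputStream()
val serStream = instance.serializeStream(bytesOut)
serStream.writeObject((Row(1), Row(42L, 3.5)))  // key and value are written field by field
serStream.close()

val deserStream = instance.deserializeStream(new ByteArrayInputStream(bytesOut.toByteArray))
val (key, value) = deserStream.readObject[(Row, Row)]()
// key.getInt(0) == 1, value.getLong(0) == 42L, value.getDouble(1) == 3.5
// Note: the returned rows are reused SpecificMutableRows; copy them if they need to
// outlive the next readObject() call.
deserStream.close()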
private[sql] object SparkSqlSerializer2 {
final val NULL = 0
final val NOT_NULL = 1
/**
* Check if rows with the given schema can be serialized by SparkSqlSerializer2.
*/
def support(schema: Array[DataType]): Boolean = {
if (schema == null) return true
var i = 0
while (i < schema.length) {
schema(i) match {
case udt: UserDefinedType[_] => return false
case array: ArrayType => return false
case map: MapType => return false
case struct: StructType => return false
case _ =>
}
i += 1
}
return true
}
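For illustration (mirroring the SparkSqlSerializer2DataTypeSuite further down), a few hypothetical inputs and the results this check produces:
// Illustrative only: primitive, date/time, string, binary, and decimal types are supported.
SparkSqlSerializer2.support(Array[DataType](IntegerType, StringType, DecimalType(10, 5)))  // true
// Complex types (and UDTs) are not.
SparkSqlSerializer2.support(Array[DataType](ArrayType(DoubleType, true)))                  // false
SparkSqlSerializer2.support(Array[DataType](MapType(IntegerType, StringType, true)))       // false
// A null schema means "no fields on this side of the shuffle" and counts as supported.
SparkSqlSerializer2.support(null)                                                          // true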
/**
* A utility function that creates the serialization function for the given schema.
*/
def createSerializationFunction(schema: Array[DataType], out: DataOutputStream): Row => Unit = {
(row: Row) =>
// If the schema is null, the returned function does nothing when it gets called.
if (schema != null) {
var i = 0
while (i < schema.length) {
schema(i) match {
// When we write a value to the underlying stream, we first write a null byte. Then, if
// the value is not null, we write the contents out.
case NullType => // Write nothing.
case BooleanType =>
if (row.isNullAt(i)) {
out.writeByte(NULL)
} else {
out.writeByte(NOT_NULL)
out.writeBoolean(row.getBoolean(i))
}
case ByteType =>
if (row.isNullAt(i)) {
out.writeByte(NULL)
} else {
out.writeByte(NOT_NULL)
out.writeByte(row.getByte(i))
}
case ShortType =>
if (row.isNullAt(i)) {
out.writeByte(NULL)
} else {
out.writeByte(NOT_NULL)
out.writeShort(row.getShort(i))
}
case IntegerType =>
if (row.isNullAt(i)) {
out.writeByte(NULL)
} else {
out.writeByte(NOT_NULL)
out.writeInt(row.getInt(i))
}
case LongType =>
if (row.isNullAt(i)) {
out.writeByte(NULL)
} else {
out.writeByte(NOT_NULL)
out.writeLong(row.getLong(i))
}
case FloatType =>
if (row.isNullAt(i)) {
out.writeByte(NULL)
} else {
out.writeByte(NOT_NULL)
out.writeFloat(row.getFloat(i))
}
case DoubleType =>
if (row.isNullAt(i)) {
out.writeByte(NULL)
} else {
out.writeByte(NOT_NULL)
out.writeDouble(row.getDouble(i))
}
case decimal: DecimalType =>
if (row.isNullAt(i)) {
out.writeByte(NULL)
} else {
out.writeByte(NOT_NULL)
val value = row.apply(i).asInstanceOf[Decimal]
val javaBigDecimal = value.toJavaBigDecimal
// First, write out the unscaled value.
val bytes: Array[Byte] = javaBigDecimal.unscaledValue().toByteArray
out.writeInt(bytes.length)
out.write(bytes)
// Then, write out the scale.
out.writeInt(javaBigDecimal.scale())
}
case DateType =>
if (row.isNullAt(i)) {
out.writeByte(NULL)
} else {
out.writeByte(NOT_NULL)
out.writeInt(row.getAs[Int](i))
}
case TimestampType =>
if (row.isNullAt(i)) {
out.writeByte(NULL)
} else {
out.writeByte(NOT_NULL)
val timestamp = row.getAs[java.sql.Timestamp](i)
val time = timestamp.getTime
val nanos = timestamp.getNanos
out.writeLong(time - (nanos / 1000000)) // Write the milliseconds value.
out.writeInt(nanos) // Write the nanoseconds part.
}
case StringType =>
if (row.isNullAt(i)) {
out.writeByte(NULL)
} else {
out.writeByte(NOT_NULL)
val bytes = row.getAs[UTF8String](i).getBytes
out.writeInt(bytes.length)
out.write(bytes)
}
case BinaryType =>
if (row.isNullAt(i)) {
out.writeByte(NULL)
} else {
out.writeByte(NOT_NULL)
val bytes = row.getAs[Array[Byte]](i)
out.writeInt(bytes.length)
out.write(bytes)
}
}
i += 1
}
}
}
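To make the wire format concrete, the byte layout this function would produce for a hypothetical row (42, null) with schema Array(IntegerType, StringType) is sketched below (an illustration, not part of the patch):
// field 0, IntegerType, value 42:   0x01 (NOT_NULL), 0x00 0x00 0x00 0x2A (writeInt(42), big-endian)
// field 1, StringType, value null:  0x00 (NULL)
// Total: 6 bytes. No field names or type tags are written; the reader must already know the
// schema, which is why the serialization and deserialization functions are built from
// keySchema/valueSchema up front.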
/**
* A utility function that creates the deserialization function for the given schema.
*/
def createDeserializationFunction(
schema: Array[DataType],
in: DataInputStream,
mutableRow: SpecificMutableRow): () => Unit = {
() => {
// If the schema is null, the returned function does nothing when it gets called.
if (schema != null) {
var i = 0
while (i < schema.length) {
schema(i) match {
// When we read a value from the underlying stream, we first read the null byte. Then, if
// the value is not null, we update the corresponding field of the mutable row.
case NullType => mutableRow.setNullAt(i) // Read nothing.
case BooleanType =>
if (in.readByte() == NULL) {
mutableRow.setNullAt(i)
} else {
mutableRow.setBoolean(i, in.readBoolean())
}
case ByteType =>
if (in.readByte() == NULL) {
mutableRow.setNullAt(i)
} else {
mutableRow.setByte(i, in.readByte())
}
case ShortType =>
if (in.readByte() == NULL) {
mutableRow.setNullAt(i)
} else {
mutableRow.setShort(i, in.readShort())
}
case IntegerType =>
if (in.readByte() == NULL) {
mutableRow.setNullAt(i)
} else {
mutableRow.setInt(i, in.readInt())
}
case LongType =>
if (in.readByte() == NULL) {
mutableRow.setNullAt(i)
} else {
mutableRow.setLong(i, in.readLong())
}
case FloatType =>
if (in.readByte() == NULL) {
mutableRow.setNullAt(i)
} else {
mutableRow.setFloat(i, in.readFloat())
}
case DoubleType =>
if (in.readByte() == NULL) {
mutableRow.setNullAt(i)
} else {
mutableRow.setDouble(i, in.readDouble())
}
case decimal: DecimalType =>
if (in.readByte() == NULL) {
mutableRow.setNullAt(i)
} else {
// First, read in the unscaled value.
val length = in.readInt()
val bytes = new Array[Byte](length)
in.readFully(bytes)
val unscaledVal = new BigInteger(bytes)
// Then, read the scale.
val scale = in.readInt()
// Finally, create the Decimal object and set it in the row.
mutableRow.update(i, Decimal(new BigDecimal(unscaledVal, scale)))
}
case DateType =>
if (in.readByte() == NULL) {
mutableRow.setNullAt(i)
} else {
mutableRow.update(i, in.readInt())
}
case TimestampType =>
if (in.readByte() == NULL) {
mutableRow.setNullAt(i)
} else {
val time = in.readLong() // Read the milliseconds value.
val nanos = in.readInt() // Read the nanoseconds part.
val timestamp = new Timestamp(time)
timestamp.setNanos(nanos)
mutableRow.update(i, timestamp)
}
case StringType =>
if (in.readByte() == NULL) {
mutableRow.setNullAt(i)
} else {
val length = in.readInt()
val bytes = new Array[Byte](length)
in.readFully(bytes)
mutableRow.update(i, UTF8String(bytes))
}
case BinaryType =>
if (in.readByte() == NULL) {
mutableRow.setNullAt(i)
} else {
val length = in.readInt()
val bytes = new Array[Byte](length)
in.readFully(bytes)
mutableRow.update(i, bytes)
}
}
i += 1
}
}
}
}
}
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.execution
import java.sql.{Timestamp, Date}
import org.scalatest.{FunSuite, BeforeAndAfterAll}
import org.apache.spark.rdd.ShuffledRDD
import org.apache.spark.serializer.Serializer
import org.apache.spark.ShuffleDependency
import org.apache.spark.sql.types._
import org.apache.spark.sql.Row
import org.apache.spark.sql.test.TestSQLContext._
import org.apache.spark.sql.{MyDenseVectorUDT, QueryTest}
class SparkSqlSerializer2DataTypeSuite extends FunSuite {
// Make sure that we will not use serializer2 for unsupported data types.
def checkSupported(dataType: DataType, isSupported: Boolean): Unit = {
val testName =
s"${if (dataType == null) null else dataType.toString} is " +
s"${if (isSupported) "supported" else "unsupported"}"
test(testName) {
assert(SparkSqlSerializer2.support(Array(dataType)) === isSupported)
}
}
checkSupported(null, isSupported = true)
checkSupported(NullType, isSupported = true)
checkSupported(BooleanType, isSupported = true)
checkSupported(ByteType, isSupported = true)
checkSupported(ShortType, isSupported = true)
checkSupported(IntegerType, isSupported = true)
checkSupported(LongType, isSupported = true)
checkSupported(FloatType, isSupported = true)
checkSupported(DoubleType, isSupported = true)
checkSupported(DateType, isSupported = true)
checkSupported(TimestampType, isSupported = true)
checkSupported(StringType, isSupported = true)
checkSupported(BinaryType, isSupported = true)
checkSupported(DecimalType(10, 5), isSupported = true)
checkSupported(DecimalType.Unlimited, isSupported = true)
// For now, ArrayType, MapType, and StructType are not supported.
checkSupported(ArrayType(DoubleType, true), isSupported = false)
checkSupported(ArrayType(StringType, false), isSupported = false)
checkSupported(MapType(IntegerType, StringType, true), isSupported = false)
checkSupported(MapType(IntegerType, ArrayType(DoubleType), false), isSupported = false)
checkSupported(StructType(StructField("a", IntegerType, true) :: Nil), isSupported = false)
// UDTs are not supported right now.
checkSupported(new MyDenseVectorUDT, isSupported = false)
}
abstract class SparkSqlSerializer2Suite extends QueryTest with BeforeAndAfterAll {
var allColumns: String = _
val serializerClass: Class[Serializer] =
classOf[SparkSqlSerializer2].asInstanceOf[Class[Serializer]]
var numShufflePartitions: Int = _
var useSerializer2: Boolean = _
override def beforeAll(): Unit = {
numShufflePartitions = conf.numShufflePartitions
useSerializer2 = conf.useSqlSerializer2
sql("set spark.sql.useSerializer2=true")
val supportedTypes =
Seq(StringType, BinaryType, NullType, BooleanType,
ByteType, ShortType, IntegerType, LongType,
FloatType, DoubleType, DecimalType.Unlimited, DecimalType(6, 5),
DateType, TimestampType)
val fields = supportedTypes.zipWithIndex.map { case (dataType, index) =>
StructField(s"col$index", dataType, true)
}
allColumns = fields.map(_.name).mkString(",")
val schema = StructType(fields)
// Create an RDD with all data types supported by SparkSqlSerializer2.
val rdd =
sparkContext.parallelize((1 to 1000), 10).map { i =>
Row(
s"str${i}: test serializer2.",
s"binary${i}: test serializer2.".getBytes("UTF-8"),
null,
i % 2 == 0,
i.toByte,
i.toShort,
i,
Long.MaxValue - i.toLong,
(i + 0.25).toFloat,
(i + 0.75),
BigDecimal(Long.MaxValue.toString + ".12345"),
new java.math.BigDecimal(s"${i % 9 + 1}" + ".23456"),
new Date(i),
new Timestamp(i))
}
createDataFrame(rdd, schema).registerTempTable("shuffle")
super.beforeAll()
}
override def afterAll(): Unit = {
dropTempTable("shuffle")
sql(s"set spark.sql.shuffle.partitions=$numShufflePartitions")
sql(s"set spark.sql.useSerializer2=$useSerializer2")
super.afterAll()
}
def checkSerializer[T <: Serializer](
executedPlan: SparkPlan,
expectedSerializerClass: Class[T]): Unit = {
executedPlan.foreach {
case exchange: Exchange =>
val shuffledRDD = exchange.execute().firstParent.asInstanceOf[ShuffledRDD[_, _, _]]
val dependency = shuffledRDD.getDependencies.head.asInstanceOf[ShuffleDependency[_, _, _]]
val serializerNotSetMessage =
s"Expected $expectedSerializerClass as the serializer of Exchange. " +
s"However, the serializer was not set."
val serializer = dependency.serializer.getOrElse(fail(serializerNotSetMessage))
assert(serializer.getClass === expectedSerializerClass)
case _ => // Ignore other nodes.
}
}
test("key schema and value schema are not nulls") {
val df = sql(s"SELECT DISTINCT ${allColumns} FROM shuffle")
checkSerializer(df.queryExecution.executedPlan, serializerClass)
checkAnswer(
df,
table("shuffle").collect())
}
test("value schema is null") {
val df = sql(s"SELECT col0 FROM shuffle ORDER BY col0")
checkSerializer(df.queryExecution.executedPlan, serializerClass)
assert(
df.map(r => r.getString(0)).collect().toSeq ===
table("shuffle").select("col0").map(r => r.getString(0)).collect().sorted.toSeq)
}
test("no map output field") {
val df = sql(s"SELECT 1 + 1 FROM shuffle")
checkSerializer(df.queryExecution.executedPlan, classOf[SparkSqlSerializer])
}
}
/** Tests SparkSqlSerializer2 with sort-based shuffle without sort merge. */
class SparkSqlSerializer2SortShuffleSuite extends SparkSqlSerializer2Suite {
override def beforeAll(): Unit = {
super.beforeAll()
// Sort merge will not be triggered.
sql("set spark.sql.shuffle.partitions = 200")
}
test("key schema is null") {
val aggregations = allColumns.split(",").map(c => s"COUNT($c)").mkString(",")
val df = sql(s"SELECT $aggregations FROM shuffle")
checkSerializer(df.queryExecution.executedPlan, serializerClass)
checkAnswer(
df,
Row(1000, 1000, 0, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000))
}
}
/** For now, we will use SparkSqlSerializer for sort-based shuffle with sort merge. */
class SparkSqlSerializer2SortMergeShuffleSuite extends SparkSqlSerializer2Suite {
// We are expecting SparkSqlSerializer.
override val serializerClass: Class[Serializer] =
classOf[SparkSqlSerializer].asInstanceOf[Class[Serializer]]
override def beforeAll(): Unit = {
super.beforeAll()
// To trigger the sort merge.
sql("set spark.sql.shuffle.partitions = 201")
}
}