Commit 9d03ad91 authored by Reynold Xin

[SPARK-9543][SQL] Add randomized testing for UnsafeKVExternalSorter.

The detailed approach is documented in UnsafeKVExternalSorterSuite.testKVSorter() and works as follows:

1. Create input by generating data randomly based on the given key/value schema (which is also randomly drawn from a list of candidate types)
2. Run UnsafeKVExternalSorter on the generated data
3. Collect the output from the sorter, and make sure the keys are sorted in ascending order
4. Sort the input by both key and value, and sort the sorter output also by both key and value. Compare the sorted input and sorted output together to make sure all the key/values match.
5. Check memory allocation to make sure there is no memory leak.

There is also a spill flag. When it is set to true, the sorter spills probabilistically, roughly once every 100 records.
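For a quick read of the flow, the heart of the suite is a small driver loop, condensed here from the new UnsafeKVExternalSorterSuite in the diff below (keyTypes, valueTypes, rand, and testKVSorter are the members the patch adds):

    import scala.util.Random
    import org.apache.spark.sql.RandomDataGenerator
    import org.apache.spark.sql.types._

    // Candidate field types that key and value schemas are drawn from.
    private val keyTypes = Seq(IntegerType, FloatType, DoubleType, StringType)
    private val valueTypes = Seq(IntegerType, FloatType, DoubleType, StringType)

    private val rand = new Random(42)
    for (i <- 0 until 6) {
      // Step 1: draw a random key schema and value schema with 1 to 10 fields each.
      val keySchema = RandomDataGenerator.randomSchema(rand.nextInt(10) + 1, keyTypes)
      val valueSchema = RandomDataGenerator.randomSchema(rand.nextInt(10) + 1, valueTypes)
      // Steps 2-5 live in testKVSorter: run the sorter, check that keys come out in
      // ascending order, compare the sorted (key, value) pairs against the sorted input,
      // and verify no memory is leaked. The last two iterations also exercise spilling.
      testKVSorter(keySchema, valueSchema, spill = i > 3)
    }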

Author: Reynold Xin <rxin@databricks.com>

Closes #7873 from rxin/kvsorter-randomized-test and squashes the following commits:

a08c251 [Reynold Xin] Resource cleanup.
0488b5c [Reynold Xin] [SPARK-9543][SQL] Add randomized testing for UnsafeKVExternalSorter.
parent 0722f433
@@ -24,6 +24,7 @@ import java.math.MathContext
import scala.util.Random
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.CalendarInterval
/**
* Random data generators for Spark SQL DataTypes. These generators do not generate uniformly random
@@ -106,6 +107,11 @@ object RandomDataGenerator {
case BooleanType => Some(() => rand.nextBoolean())
case DateType => Some(() => new java.sql.Date(rand.nextInt()))
case TimestampType => Some(() => new java.sql.Timestamp(rand.nextLong()))
case CalendarIntervalType => Some(() => {
val months = rand.nextInt(1000)
val ns = rand.nextLong()
new CalendarInterval(months, ns)
})
case DecimalType.Fixed(precision, scale) => Some(
() => BigDecimal.apply(
rand.nextLong() % math.pow(10, precision).toLong,
......
@@ -56,11 +56,21 @@ class UnsafeFixedWidthAggregationMapSuite extends SparkFunSuite with Matchers {
assert(leakedShuffleMemory === 0)
taskMemoryManager = null
}
TaskContext.unset()
}
test(name) {
taskMemoryManager = new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP))
shuffleMemoryManager = new TestShuffleMemoryManager
TaskContext.setTaskContext(new TaskContextImpl(
stageId = 0,
partitionId = 0,
taskAttemptId = Random.nextInt(10000),
attemptNumber = 0,
taskMemoryManager = taskMemoryManager,
metricsSystem = null))
try {
f
} catch {
@@ -163,14 +173,6 @@ class UnsafeFixedWidthAggregationMapSuite extends SparkFunSuite with Matchers {
// Calling this makes sure we have the block manager and everything else set up.
TestSQLContext
TaskContext.setTaskContext(new TaskContextImpl(
stageId = 0,
partitionId = 0,
taskAttemptId = 0,
attemptNumber = 0,
taskMemoryManager = taskMemoryManager,
metricsSystem = null))
// Memory consumption in the beginning of the task.
val initialMemoryConsumption = shuffleMemoryManager.getMemoryConsumptionForThisTask()
......
@@ -19,140 +19,136 @@ package org.apache.spark.sql.execution
import scala.util.Random
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{RowOrdering, UnsafeProjection}
import org.apache.spark._
import org.apache.spark.sql.RandomDataGenerator
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.catalyst.expressions.{UnsafeRow, RowOrdering, UnsafeProjection}
import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, MemoryAllocator, TaskMemoryManager}
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark._
/**
* Test suite for [[UnsafeKVExternalSorter]], with randomly generated test data.
*/
class UnsafeKVExternalSorterSuite extends SparkFunSuite {
test("sorting string key and int int value") {
// Calling this makes sure we have the block manager and everything else set up.
TestSQLContext
private val keyTypes = Seq(IntegerType, FloatType, DoubleType, StringType)
private val valueTypes = Seq(IntegerType, FloatType, DoubleType, StringType)
val taskMemMgr = new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP))
val shuffleMemMgr = new TestShuffleMemoryManager
testKVSorter(new StructType, new StructType, spill = true)
testKVSorter(new StructType().add("c1", IntegerType), new StructType, spill = true)
testKVSorter(new StructType, new StructType().add("c1", IntegerType), spill = true)
TaskContext.setTaskContext(new TaskContextImpl(
stageId = 0,
partitionId = 0,
taskAttemptId = 0,
attemptNumber = 0,
taskMemoryManager = taskMemMgr,
metricsSystem = null))
val keySchema = new StructType().add("a", StringType)
val valueSchema = new StructType().add("b", IntegerType).add("c", IntegerType)
val sorter = new UnsafeKVExternalSorter(
keySchema, valueSchema, SparkEnv.get.blockManager, shuffleMemMgr,
16 * 1024)
val keyConverter = UnsafeProjection.create(keySchema)
val valueConverter = UnsafeProjection.create(valueSchema)
private val rand = new Random(42)
for (i <- 0 until 6) {
val keySchema = RandomDataGenerator.randomSchema(rand.nextInt(10) + 1, keyTypes)
val valueSchema = RandomDataGenerator.randomSchema(rand.nextInt(10) + 1, valueTypes)
testKVSorter(keySchema, valueSchema, spill = i > 3)
}
val rand = new Random(42)
val data = null +: Seq.fill[String](10) {
Seq.fill(rand.nextInt(100))(rand.nextPrintableChar()).mkString
}
/**
* Create a test case using randomly generated data for the given key and value schema.
*
* The approach works as follows:
*
* - Create input by randomly generating data based on the given schema
* - Run [[UnsafeKVExternalSorter]] on the generated data
* - Collect the output from the sorter, and make sure the keys are sorted in ascending order
* - Sort the input by both key and value, and sort the sorter output also by both key and value.
* Compare the sorted input and sorted output together to make sure all the key/values match.
*
* If spill is set to true, the sorter will spill probabilistically roughly every 100 records.
*/
private def testKVSorter(keySchema: StructType, valueSchema: StructType, spill: Boolean): Unit = {
val keySchemaStr = keySchema.map(_.dataType.simpleString).mkString("[", ",", "]")
val valueSchemaStr = valueSchema.map(_.dataType.simpleString).mkString("[", ",", "]")
test(s"kv sorting key schema $keySchemaStr and value schema $valueSchemaStr") {
// Calling this makes sure we have the block manager and everything else set up.
TestSQLContext
val taskMemMgr = new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP))
val shuffleMemMgr = new TestShuffleMemoryManager
TaskContext.setTaskContext(new TaskContextImpl(
stageId = 0,
partitionId = 0,
taskAttemptId = 98456,
attemptNumber = 0,
taskMemoryManager = taskMemMgr,
metricsSystem = null))
// Create the data converters
val kExternalConverter = CatalystTypeConverters.createToCatalystConverter(keySchema)
val vExternalConverter = CatalystTypeConverters.createToCatalystConverter(valueSchema)
val kConverter = UnsafeProjection.create(keySchema)
val vConverter = UnsafeProjection.create(valueSchema)
val keyDataGen = RandomDataGenerator.forType(keySchema, nullable = false).get
val valueDataGen = RandomDataGenerator.forType(valueSchema, nullable = false).get
val input = Seq.fill(1024) {
val k = kConverter(kExternalConverter.apply(keyDataGen.apply()).asInstanceOf[InternalRow])
val v = vConverter(vExternalConverter.apply(valueDataGen.apply()).asInstanceOf[InternalRow])
(k.asInstanceOf[InternalRow].copy(), v.asInstanceOf[InternalRow].copy())
}
val inputRows = data.map { str =>
keyConverter.apply(InternalRow(UTF8String.fromString(str))).copy()
}
val sorter = new UnsafeKVExternalSorter(
keySchema, valueSchema, SparkEnv.get.blockManager, shuffleMemMgr, 16 * 1024 * 1024)
var i = 0
data.foreach { str =>
if (str != null) {
val k = InternalRow(UTF8String.fromString(str))
val v = InternalRow(str.length, str.length + 1)
sorter.insertKV(keyConverter.apply(k), valueConverter.apply(v))
} else {
val k = InternalRow(UTF8String.fromString(str))
val v = InternalRow(-1, -2)
sorter.insertKV(keyConverter.apply(k), valueConverter.apply(v))
// Insert generated keys and values into the sorter
input.foreach { case (k, v) =>
sorter.insertKV(k.asInstanceOf[UnsafeRow], v.asInstanceOf[UnsafeRow])
// 1% chance we will spill
if (rand.nextDouble() < 0.01 && spill) {
shuffleMemMgr.markAsOutOfMemory()
sorter.closeCurrentPage()
}
}
if ((i % 100) == 0) {
shuffleMemMgr.markAsOutOfMemory()
sorter.closeCurrentPage()
// Collect the sorted output
val out = new scala.collection.mutable.ArrayBuffer[(InternalRow, InternalRow)]
val iter = sorter.sortedIterator()
while (iter.next()) {
out += Tuple2(iter.getKey.copy(), iter.getValue.copy())
}
i += 1
}
val out = new scala.collection.mutable.ArrayBuffer[InternalRow]
val iter = sorter.sortedIterator()
while (iter.next()) {
if (iter.getKey.getUTF8String(0) == null) {
withClue(s"for null key") {
assert(-1 === iter.getValue.getInt(0))
assert(-2 === iter.getValue.getInt(1))
}
} else {
val key = iter.getKey.getString(0)
withClue(s"for key $key") {
assert(key.length === iter.getValue.getInt(0))
assert(key.length + 1 === iter.getValue.getInt(1))
val keyOrdering = RowOrdering.forSchema(keySchema.map(_.dataType))
val valueOrdering = RowOrdering.forSchema(valueSchema.map(_.dataType))
val kvOrdering = new Ordering[(InternalRow, InternalRow)] {
override def compare(x: (InternalRow, InternalRow), y: (InternalRow, InternalRow)): Int = {
keyOrdering.compare(x._1, y._1) match {
case 0 => valueOrdering.compare(x._2, y._2)
case cmp => cmp
}
}
}
out += iter.getKey.copy()
}
assert(out === inputRows.sorted(RowOrdering.forSchema(keySchema.map(_.dataType))))
}
test("sorting arbitrary string data") {
// Calling this makes sure we have the block manager and everything else set up.
TestSQLContext
val taskMemMgr = new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP))
val shuffleMemMgr = new TestShuffleMemoryManager
TaskContext.setTaskContext(new TaskContextImpl(
stageId = 0,
partitionId = 0,
taskAttemptId = 0,
attemptNumber = 0,
taskMemoryManager = taskMemMgr,
metricsSystem = null))
val keySchema = new StructType().add("a", StringType)
val valueSchema = new StructType().add("b", IntegerType)
val sorter = new UnsafeKVExternalSorter(
keySchema, valueSchema, SparkEnv.get.blockManager, shuffleMemMgr,
16 * 1024)
val keyConverter = UnsafeProjection.create(keySchema)
val valueConverter = UnsafeProjection.create(valueSchema)
val rand = new Random(42)
val data = Seq.fill(512) {
Seq.fill(rand.nextInt(100))(rand.nextPrintableChar()).mkString
}
// Testing to make sure output from the sorter is sorted by key
var prevK: InternalRow = null
out.zipWithIndex.foreach { case ((k, v), i) =>
if (prevK != null) {
assert(keyOrdering.compare(prevK, k) <= 0,
s"""
|key is not in sorted order:
|previous key: $prevK
|current key : $k
""".stripMargin)
}
prevK = k
}
var i = 0
data.foreach { str =>
val k = InternalRow(UTF8String.fromString(str))
val v = InternalRow(str.length)
sorter.insertKV(keyConverter.apply(k), valueConverter.apply(v))
// Testing to make sure the key/value in output matches input
assert(out.sorted(kvOrdering) === input.sorted(kvOrdering))
if ((i % 100) == 0) {
shuffleMemMgr.markAsOutOfMemory()
sorter.closeCurrentPage()
// Make sure there is no memory leak
val leakedUnsafeMemory: Long = taskMemMgr.cleanUpAllAllocatedMemory
if (shuffleMemMgr != null) {
val leakedShuffleMemory: Long = shuffleMemMgr.getMemoryConsumptionForThisTask()
assert(0L === leakedShuffleMemory)
}
i += 1
assert(0 === leakedUnsafeMemory)
TaskContext.unset()
}
val out = new scala.collection.mutable.ArrayBuffer[String]
val iter = sorter.sortedIterator()
while (iter.next()) {
assert(iter.getKey.getString(0).length === iter.getValue.getInt(0))
out += iter.getKey.getString(0)
}
assert(out === data.sorted)
}
}
@@ -22,7 +22,7 @@ import org.junit.Test;
import static junit.framework.Assert.*;
import static org.apache.spark.unsafe.types.CalendarInterval.*;
-public class IntervalSuite {
+public class CalendarIntervalSuite {
@Test
public void equalsTest() {
......