Commit 31345fde authored by Sameer Agarwal, committed by Wenchen Fan

[SPARK-20451] Filter out nested mapType datatypes from sort order in randomSplit

## What changes were proposed in this pull request?

In `randomSplit`, it is possible that the underlying dataset doesn't guarantee the ordering of rows in its constituent partitions each time a split is materialized, which could result in overlapping splits.

To prevent this, as part of SPARK-12662, we explicitly sort each input partition to make the ordering deterministic. Given that `MapTypes` cannot be sorted, this patch explicitly prunes them out from the sort order. Additionally, if the resulting sort order is empty, this patch materializes the dataset to guarantee determinism.
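To make the failure mode concrete, here is a minimal repro sketch (an illustration, not part of the patch; it assumes a local `SparkSession` named `spark`, and the data shape mirrors the new test case below). An array-of-maps column is not itself a `MapType`, so the previous top-level `isInstanceOf[MapType]` filter left it in the injected per-partition sort order, where the query then failed because the type is not orderable:

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[2]").appName("repro").getOrCreate()
import spark.implicits._

// The second column has type array<map<int,string>>: there is no top-level
// MapType, but the column is still unsortable because its element type
// contains a map.
val df = spark.sparkContext.parallelize(1 to 600, 2)
  .map(i => (i, Array(Map(i -> i.toString))))
  .toDF("int", "arrayOfMaps")

// Before this patch, the sort injected by randomSplit kept `arrayOfMaps` in
// its sort order and the split failed; with the patch, the column is pruned
// and the split succeeds.
val splits = df.randomSplit(Array[Double](2, 3), seed = 1)
```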

## How was this patch tested?

Extended `randomSplit on reordered partitions` in `DataFrameStatSuite` to also test dataframes with mapTypes and nested mapTypes.

Author: Sameer Agarwal <sameerag@cs.berkeley.edu>

Closes #17751 from sameeragarwal/randomsplit2.
parent f44c8a84
@@ -1726,15 +1726,23 @@ class Dataset[T] private[sql](
     // It is possible that the underlying dataframe doesn't guarantee the ordering of rows in its
     // constituent partitions each time a split is materialized which could result in
     // overlapping splits. To prevent this, we explicitly sort each input partition to make the
-    // ordering deterministic.
-    // MapType cannot be sorted.
-    val sorted = Sort(logicalPlan.output.filterNot(_.dataType.isInstanceOf[MapType])
-      .map(SortOrder(_, Ascending)), global = false, logicalPlan)
+    // ordering deterministic. Note that MapTypes cannot be sorted and are explicitly pruned out
+    // from the sort order.
+    val sortOrder = logicalPlan.output
+      .filter(attr => RowOrdering.isOrderable(attr.dataType))
+      .map(SortOrder(_, Ascending))
+    val plan = if (sortOrder.nonEmpty) {
+      Sort(sortOrder, global = false, logicalPlan)
+    } else {
+      // SPARK-12662: If sort order is empty, we materialize the dataset to guarantee determinism
+      cache()
+      logicalPlan
+    }
     val sum = weights.sum
     val normalizedCumWeights = weights.map(_ / sum).scanLeft(0.0d)(_ + _)
     normalizedCumWeights.sliding(2).map { x =>
       new Dataset[T](
-        sparkSession, Sample(x(0), x(1), withReplacement = false, seed, sorted)(), encoder)
+        sparkSession, Sample(x(0), x(1), withReplacement = false, seed, plan)(), encoder)
     }.toArray
   }
...
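As an aside on the new check: `RowOrdering.isOrderable` is an internal Catalyst helper, and unlike the old top-level `isInstanceOf[MapType]` test it recurses into arrays and structs, so maps nested at any depth disqualify a column from the sort order. A hedged sketch of its behavior (package path as of Spark 2.x; internal APIs can move between releases):

```scala
import org.apache.spark.sql.catalyst.expressions.RowOrdering
import org.apache.spark.sql.types._

val mapType = MapType(IntegerType, StringType)

RowOrdering.isOrderable(IntegerType)        // true: atomic types are orderable
RowOrdering.isOrderable(mapType)            // false: maps define no ordering
RowOrdering.isOrderable(ArrayType(mapType)) // false: element type contains a map
RowOrdering.isOrderable(StructType(Seq(StructField("m", mapType)))) // false

// If every column fails this check, the filtered sort order is empty and
// randomSplit takes the cache()-then-split branch above instead of sorting.
```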
@@ -68,25 +68,38 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext {
   }
 
   test("randomSplit on reordered partitions") {
-    // This test ensures that randomSplit does not create overlapping splits even when the
-    // underlying dataframe (such as the one below) doesn't guarantee a deterministic ordering of
-    // rows in each partition.
-    val data =
-      sparkContext.parallelize(1 to 600, 2).mapPartitions(scala.util.Random.shuffle(_)).toDF("id")
-    val splits = data.randomSplit(Array[Double](2, 3), seed = 1)
-    assert(splits.length == 2, "wrong number of splits")
 
-    // Verify that the splits span the entire dataset
-    assert(splits.flatMap(_.collect()).toSet == data.collect().toSet)
+    def testNonOverlappingSplits(data: DataFrame): Unit = {
+      val splits = data.randomSplit(Array[Double](2, 3), seed = 1)
+      assert(splits.length == 2, "wrong number of splits")
 
-    // Verify that the splits don't overlap
-    assert(splits(0).intersect(splits(1)).collect().isEmpty)
+      // Verify that the splits span the entire dataset
+      assert(splits.flatMap(_.collect()).toSet == data.collect().toSet)
 
-    // Verify that the results are deterministic across multiple runs
-    val firstRun = splits.toSeq.map(_.collect().toSeq)
-    val secondRun = data.randomSplit(Array[Double](2, 3), seed = 1).toSeq.map(_.collect().toSeq)
-    assert(firstRun == secondRun)
+      // Verify that the splits don't overlap
+      assert(splits(0).collect().toSeq.intersect(splits(1).collect().toSeq).isEmpty)
+
+      // Verify that the results are deterministic across multiple runs
+      val firstRun = splits.toSeq.map(_.collect().toSeq)
+      val secondRun = data.randomSplit(Array[Double](2, 3), seed = 1).toSeq.map(_.collect().toSeq)
+      assert(firstRun == secondRun)
+    }
+
+    // This test ensures that randomSplit does not create overlapping splits even when the
+    // underlying dataframe (such as the one below) doesn't guarantee a deterministic ordering of
+    // rows in each partition.
+    val dataWithInts = sparkContext.parallelize(1 to 600, 2)
+      .mapPartitions(scala.util.Random.shuffle(_)).toDF("int")
+    val dataWithMaps = sparkContext.parallelize(1 to 600, 2)
+      .map(i => (i, Map(i -> i.toString)))
+      .mapPartitions(scala.util.Random.shuffle(_)).toDF("int", "map")
+    val dataWithArrayOfMaps = sparkContext.parallelize(1 to 600, 2)
+      .map(i => (i, Array(Map(i -> i.toString))))
+      .mapPartitions(scala.util.Random.shuffle(_)).toDF("int", "arrayOfMaps")
+
+    testNonOverlappingSplits(dataWithInts)
+    testNonOverlappingSplits(dataWithMaps)
+    testNonOverlappingSplits(dataWithArrayOfMaps)
   }
 
   test("pearson correlation") {
...
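One detail of the test rewrite worth noting: the overlap check moved from `Dataset.intersect` to intersecting collected sequences. `intersect` is a set operation, and Spark's analyzer rejects set operations on dataframes with map-typed columns, so the comparison has to happen on the driver. A minimal sketch of the distinction (illustrative only; assumes a local `SparkSession`):

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[2]").appName("demo").getOrCreate()
import spark.implicits._

val df = Seq((1, Map(1 -> "a")), (2, Map(2 -> "b"))).toDF("int", "map")

// df.intersect(df)  // fails analysis: map type columns are not allowed in
//                   // set operations (intersect, except, etc.)

// Collecting first keeps the comparison on the driver, where Seq#intersect
// compares rows structurally (Scala maps use value equality).
val overlap = df.collect().toSeq.intersect(df.collect().toSeq)
assert(overlap.nonEmpty) // identical inputs fully overlap, as expected
```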