From 31345fde82ada1f8bb12807b250b04726a1f6aa6 Mon Sep 17 00:00:00 2001
From: Sameer Agarwal <sameerag@cs.berkeley.edu>
Date: Tue, 25 Apr 2017 13:05:20 +0800
Subject: [PATCH] [SPARK-20451] Filter out nested mapType datatypes from sort order in randomSplit

## What changes were proposed in this pull request?

In `randomSplit`, it is possible that the underlying dataset doesn't guarantee the ordering of rows in its constituent partitions each time a split is materialized, which could result in overlapping splits. To prevent this, as part of SPARK-12662, we explicitly sort each input partition to make the ordering deterministic. Given that `MapTypes` cannot be sorted, this patch explicitly prunes them out from the sort order. Additionally, if the resulting sort order is empty, this patch then materializes the dataset to guarantee determinism.

## How was this patch tested?

Extended `randomSplit on reordered partitions` in `DataFrameStatSuite` to also test for dataframes with mapTypes and nested mapTypes.

Author: Sameer Agarwal <sameerag@cs.berkeley.edu>

Closes #17751 from sameeragarwal/randomsplit2.
---
 .../scala/org/apache/spark/sql/Dataset.scala  | 18 +++++---
 .../apache/spark/sql/DataFrameStatSuite.scala | 43 ++++++++++++-------
 2 files changed, 41 insertions(+), 20 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index c6dcd93bbd..06dd550071 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -1726,15 +1726,23 @@ class Dataset[T] private[sql](
     // It is possible that the underlying dataframe doesn't guarantee the ordering of rows in its
     // constituent partitions each time a split is materialized which could result in
     // overlapping splits. To prevent this, we explicitly sort each input partition to make the
-    // ordering deterministic.
-    // MapType cannot be sorted.
-    val sorted = Sort(logicalPlan.output.filterNot(_.dataType.isInstanceOf[MapType])
-      .map(SortOrder(_, Ascending)), global = false, logicalPlan)
+    // ordering deterministic. Note that MapTypes cannot be sorted and are explicitly pruned out
+    // from the sort order.
+    val sortOrder = logicalPlan.output
+      .filter(attr => RowOrdering.isOrderable(attr.dataType))
+      .map(SortOrder(_, Ascending))
+    val plan = if (sortOrder.nonEmpty) {
+      Sort(sortOrder, global = false, logicalPlan)
+    } else {
+      // SPARK-12662: If sort order is empty, we materialize the dataset to guarantee determinism
+      cache()
+      logicalPlan
+    }
     val sum = weights.sum
     val normalizedCumWeights = weights.map(_ / sum).scanLeft(0.0d)(_ + _)
     normalizedCumWeights.sliding(2).map { x =>
       new Dataset[T](
-        sparkSession, Sample(x(0), x(1), withReplacement = false, seed, sorted)(), encoder)
+        sparkSession, Sample(x(0), x(1), withReplacement = false, seed, plan)(), encoder)
     }.toArray
   }

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
index 97890a035a..dd118f88e3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
@@ -68,25 +68,38 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext {
   }

   test("randomSplit on reordered partitions") {
-    // This test ensures that randomSplit does not create overlapping splits even when the
-    // underlying dataframe (such as the one below) doesn't guarantee a deterministic ordering of
-    // rows in each partition.
-    val data =
-      sparkContext.parallelize(1 to 600, 2).mapPartitions(scala.util.Random.shuffle(_)).toDF("id")
-    val splits = data.randomSplit(Array[Double](2, 3), seed = 1)
-    assert(splits.length == 2, "wrong number of splits")
+    def testNonOverlappingSplits(data: DataFrame): Unit = {
+      val splits = data.randomSplit(Array[Double](2, 3), seed = 1)
+      assert(splits.length == 2, "wrong number of splits")
+
+      // Verify that the splits span the entire dataset
+      assert(splits.flatMap(_.collect()).toSet == data.collect().toSet)

-    // Verify that the splits span the entire dataset
-    assert(splits.flatMap(_.collect()).toSet == data.collect().toSet)
+      // Verify that the splits don't overlap
+      assert(splits(0).collect().toSeq.intersect(splits(1).collect().toSeq).isEmpty)

-    // Verify that the splits don't overlap
-    assert(splits(0).intersect(splits(1)).collect().isEmpty)
+      // Verify that the results are deterministic across multiple runs
+      val firstRun = splits.toSeq.map(_.collect().toSeq)
+      val secondRun = data.randomSplit(Array[Double](2, 3), seed = 1).toSeq.map(_.collect().toSeq)
+      assert(firstRun == secondRun)
+    }

-    // Verify that the results are deterministic across multiple runs
-    val firstRun = splits.toSeq.map(_.collect().toSeq)
-    val secondRun = data.randomSplit(Array[Double](2, 3), seed = 1).toSeq.map(_.collect().toSeq)
-    assert(firstRun == secondRun)
+    // This test ensures that randomSplit does not create overlapping splits even when the
+    // underlying dataframe (such as the one below) doesn't guarantee a deterministic ordering of
+    // rows in each partition.
+    val dataWithInts = sparkContext.parallelize(1 to 600, 2)
+      .mapPartitions(scala.util.Random.shuffle(_)).toDF("int")
+    val dataWithMaps = sparkContext.parallelize(1 to 600, 2)
+      .map(i => (i, Map(i -> i.toString)))
+      .mapPartitions(scala.util.Random.shuffle(_)).toDF("int", "map")
+    val dataWithArrayOfMaps = sparkContext.parallelize(1 to 600, 2)
+      .map(i => (i, Array(Map(i -> i.toString))))
+      .mapPartitions(scala.util.Random.shuffle(_)).toDF("int", "arrayOfMaps")
+
+    testNonOverlappingSplits(dataWithInts)
+    testNonOverlappingSplits(dataWithMaps)
+    testNonOverlappingSplits(dataWithArrayOfMaps)
   }

   test("pearson correlation") {
--
GitLab
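For reference, here is a minimal standalone sketch (not part of the patch) of the user-facing behavior this change guards: calling `randomSplit` on a DataFrame that contains a `MapType` column and checking that the splits neither overlap nor change across runs. The `spark` session, the object name, and the column names are hypothetical; the authoritative check is the `DataFrameStatSuite` test in the diff above.

```scala
import org.apache.spark.sql.SparkSession

// Hypothetical demo object, not part of the Spark code base.
object RandomSplitOnMapsDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[2]")
      .appName("randomSplit-mapType-demo")
      .getOrCreate()
    import spark.implicits._

    // A DataFrame with a MapType column. The map column is not orderable, so after this
    // patch randomSplit sorts only on the orderable "int" column to keep row order
    // deterministic (and would fall back to caching if no column were orderable).
    val df = spark.sparkContext.parallelize(1 to 100, 2)
      .map(i => (i, Map(i -> i.toString)))
      .toDF("int", "map")

    val Array(a, b) = df.randomSplit(Array(0.4, 0.6), seed = 42)
    val rowsA = a.collect().toSeq
    val rowsB = b.collect().toSeq

    // The splits are disjoint and together cover the whole dataset.
    assert(rowsA.intersect(rowsB).isEmpty)
    assert((rowsA ++ rowsB).toSet == df.collect().toSet)

    // Splitting again with the same seed reproduces the same assignment of rows.
    val Array(a2, b2) = df.randomSplit(Array(0.4, 0.6), seed = 42)
    assert(a2.collect().toSeq == rowsA && b2.collect().toSeq == rowsB)

    spark.stop()
  }
}
```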