diff --git a/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala b/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala index 786b97ad7b9ece4e5c2d1c646286db415ca26b8a..c156b03cdb7c4f2511cdab1ca84c4cac7d021ef4 100644 --- a/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala +++ b/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala @@ -176,10 +176,15 @@ class BernoulliSampler[T: ClassTag](fraction: Double) extends RandomSampler[T, T * A sampler for sampling with replacement, based on values drawn from Poisson distribution. * * @param fraction the sampling fraction (with replacement) + * @param useGapSamplingIfPossible if true, use gap sampling when sampling ratio is low. * @tparam T item type */ @DeveloperApi -class PoissonSampler[T: ClassTag](fraction: Double) extends RandomSampler[T, T] { +class PoissonSampler[T: ClassTag]( + fraction: Double, + useGapSamplingIfPossible: Boolean) extends RandomSampler[T, T] { + + def this(fraction: Double) = this(fraction, useGapSamplingIfPossible = true) /** Epsilon slop to avoid failure from floating point jitter. */ require( @@ -199,17 +204,18 @@ class PoissonSampler[T: ClassTag](fraction: Double) extends RandomSampler[T, T] override def sample(items: Iterator[T]): Iterator[T] = { if (fraction <= 0.0) { Iterator.empty - } else if (fraction <= RandomSampler.defaultMaxGapSamplingFraction) { - new GapSamplingReplacementIterator(items, fraction, rngGap, RandomSampler.rngEpsilon) + } else if (useGapSamplingIfPossible && + fraction <= RandomSampler.defaultMaxGapSamplingFraction) { + new GapSamplingReplacementIterator(items, fraction, rngGap, RandomSampler.rngEpsilon) } else { - items.flatMap { item => { + items.flatMap { item => val count = rng.sample() if (count == 0) Iterator.empty else Iterator.fill(count)(item) - }} + } } } - override def clone: PoissonSampler[T] = new PoissonSampler[T](fraction) + override def clone: PoissonSampler[T] = new PoissonSampler[T](fraction, useGapSamplingIfPossible) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index 0680f31d40f6db6509d520b8a8c3df1d8b65db32..c5d1ed0937b19de5ffc25e5b89be159bc9c08fca 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.rdd.{RDD, ShuffledRDD} +import org.apache.spark.rdd.{PartitionwiseSampledRDD, RDD, ShuffledRDD} import org.apache.spark.shuffle.sort.SortShuffleManager import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow @@ -30,6 +30,7 @@ import org.apache.spark.sql.metric.SQLMetrics import org.apache.spark.sql.types.StructType import org.apache.spark.util.collection.ExternalSorter import org.apache.spark.util.collection.unsafe.sort.PrefixComparator +import org.apache.spark.util.random.PoissonSampler import org.apache.spark.util.{CompletionIterator, MutablePair} import org.apache.spark.{HashPartitioner, SparkEnv} @@ -130,12 +131,21 @@ case class Sample( { override def output: Seq[Attribute] = child.output - // TODO: How to pick seed? + override def outputsUnsafeRows: Boolean = child.outputsUnsafeRows + override def canProcessUnsafeRows: Boolean = true + override def canProcessSafeRows: Boolean = true + protected override def doExecute(): RDD[InternalRow] = { if (withReplacement) { - child.execute().map(_.copy()).sample(withReplacement, upperBound - lowerBound, seed) + // Disable gap sampling since the gap sampling method buffers two rows internally, + // requiring us to copy the row, which is more expensive than the random number generator. + new PartitionwiseSampledRDD[InternalRow, InternalRow]( + child.execute(), + new PoissonSampler[InternalRow](upperBound - lowerBound, useGapSamplingIfPossible = false), + preservesPartitioning = true, + seed) } else { - child.execute().map(_.copy()).randomSampleWithRange(lowerBound, upperBound, seed) + child.execute().randomSampleWithRange(lowerBound, upperBound, seed) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala index 0e7659f443ecd764619c9ff54ba9e5a481f46c9a..8f5984e4a8ce27e2dcd472eee9610f339eaa2f4e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala @@ -30,6 +30,41 @@ class DataFrameStatSuite extends QueryTest { private def toLetter(i: Int): String = (i + 97).toChar.toString + test("sample with replacement") { + val n = 100 + val data = sqlCtx.sparkContext.parallelize(1 to n, 2).toDF("id") + checkAnswer( + data.sample(withReplacement = true, 0.05, seed = 13), + Seq(5, 10, 52, 73).map(Row(_)) + ) + } + + test("sample without replacement") { + val n = 100 + val data = sqlCtx.sparkContext.parallelize(1 to n, 2).toDF("id") + checkAnswer( + data.sample(withReplacement = false, 0.05, seed = 13), + Seq(16, 23, 88, 100).map(Row(_)) + ) + } + + test("randomSplit") { + val n = 600 + val data = sqlCtx.sparkContext.parallelize(1 to n, 2).toDF("id") + for (seed <- 1 to 5) { + val splits = data.randomSplit(Array[Double](1, 2, 3), seed) + assert(splits.length == 3, "wrong number of splits") + + assert(splits.reduce((a, b) => a.unionAll(b)).sort("id").collect().toList == + data.collect().toList, "incomplete or wrong split") + + val s = splits.map(_.count()) + assert(math.abs(s(0) - 100) < 50) // std = 9.13 + assert(math.abs(s(1) - 200) < 50) // std = 11.55 + assert(math.abs(s(2) - 300) < 50) // std = 12.25 + } + } + test("pearson correlation") { val df = Seq.tabulate(10)(i => (i, 2 * i, i * -1.0)).toDF("a", "b", "c") val corr1 = df.stat.corr("a", "b", "pearson") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index f9cc6d1f3c250156d3cdc0b0104f0ab4f0f00d6b..0212637a829e5653a0446e3f1b59fc17bfdd6df6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -415,23 +415,6 @@ class DataFrameSuite extends QueryTest with SQLTestUtils { assert(df.schema.map(_.name) === Seq("key", "valueRenamed", "newCol")) } - test("randomSplit") { - val n = 600 - val data = sqlContext.sparkContext.parallelize(1 to n, 2).toDF("id") - for (seed <- 1 to 5) { - val splits = data.randomSplit(Array[Double](1, 2, 3), seed) - assert(splits.length == 3, "wrong number of splits") - - assert(splits.reduce((a, b) => a.unionAll(b)).sort("id").collect().toList == - data.collect().toList, "incomplete or wrong split") - - val s = splits.map(_.count()) - assert(math.abs(s(0) - 100) < 50) // std = 9.13 - assert(math.abs(s(1) - 200) < 50) // std = 11.55 - assert(math.abs(s(2) - 300) < 50) // std = 12.25 - } - } - test("describe") { val describeTestData = Seq( ("Bob", 16, 176),