Skip to content
Snippets Groups Projects
Commit e9c36938 authored by Reynold Xin's avatar Reynold Xin
Browse files

[SPARK-9752][SQL] Support UnsafeRow in Sample operator.

In order for this to work, I had to disable gap sampling.

Author: Reynold Xin <rxin@databricks.com>

Closes #8040 from rxin/SPARK-9752 and squashes the following commits:

f9e248c [Reynold Xin] Fix the test case for real this time.
adbccb3 [Reynold Xin] Fixed test case.
589fb23 [Reynold Xin] Merge branch 'SPARK-9752' of github.com:rxin/spark into SPARK-9752
55ccddc [Reynold Xin] Fixed core test.
78fa895 [Reynold Xin] [SPARK-9752][SQL] Support UnsafeRow in Sample operator.
c9e7112 [Reynold Xin] [SPARK-9752][SQL] Support UnsafeRow in Sample operator.
parent 3ca995b7
No related branches found
No related tags found
No related merge requests found
......@@ -176,10 +176,15 @@ class BernoulliSampler[T: ClassTag](fraction: Double) extends RandomSampler[T, T
* A sampler for sampling with replacement, based on values drawn from Poisson distribution.
*
* @param fraction the sampling fraction (with replacement)
* @param useGapSamplingIfPossible if true, use gap sampling when sampling ratio is low.
* @tparam T item type
*/
@DeveloperApi
class PoissonSampler[T: ClassTag](fraction: Double) extends RandomSampler[T, T] {
class PoissonSampler[T: ClassTag](
fraction: Double,
useGapSamplingIfPossible: Boolean) extends RandomSampler[T, T] {
def this(fraction: Double) = this(fraction, useGapSamplingIfPossible = true)
/** Epsilon slop to avoid failure from floating point jitter. */
require(
......@@ -199,17 +204,18 @@ class PoissonSampler[T: ClassTag](fraction: Double) extends RandomSampler[T, T]
override def sample(items: Iterator[T]): Iterator[T] = {
if (fraction <= 0.0) {
Iterator.empty
} else if (fraction <= RandomSampler.defaultMaxGapSamplingFraction) {
new GapSamplingReplacementIterator(items, fraction, rngGap, RandomSampler.rngEpsilon)
} else if (useGapSamplingIfPossible &&
fraction <= RandomSampler.defaultMaxGapSamplingFraction) {
new GapSamplingReplacementIterator(items, fraction, rngGap, RandomSampler.rngEpsilon)
} else {
items.flatMap { item => {
items.flatMap { item =>
val count = rng.sample()
if (count == 0) Iterator.empty else Iterator.fill(count)(item)
}}
}
}
}
override def clone: PoissonSampler[T] = new PoissonSampler[T](fraction)
override def clone: PoissonSampler[T] = new PoissonSampler[T](fraction, useGapSamplingIfPossible)
}
......
......@@ -18,7 +18,7 @@
package org.apache.spark.sql.execution
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.rdd.{RDD, ShuffledRDD}
import org.apache.spark.rdd.{PartitionwiseSampledRDD, RDD, ShuffledRDD}
import org.apache.spark.shuffle.sort.SortShuffleManager
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.InternalRow
......@@ -30,6 +30,7 @@ import org.apache.spark.sql.metric.SQLMetrics
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.collection.ExternalSorter
import org.apache.spark.util.collection.unsafe.sort.PrefixComparator
import org.apache.spark.util.random.PoissonSampler
import org.apache.spark.util.{CompletionIterator, MutablePair}
import org.apache.spark.{HashPartitioner, SparkEnv}
......@@ -130,12 +131,21 @@ case class Sample(
{
override def output: Seq[Attribute] = child.output
// TODO: How to pick seed?
override def outputsUnsafeRows: Boolean = child.outputsUnsafeRows
override def canProcessUnsafeRows: Boolean = true
override def canProcessSafeRows: Boolean = true
protected override def doExecute(): RDD[InternalRow] = {
if (withReplacement) {
child.execute().map(_.copy()).sample(withReplacement, upperBound - lowerBound, seed)
// Disable gap sampling since the gap sampling method buffers two rows internally,
// requiring us to copy the row, which is more expensive than the random number generator.
new PartitionwiseSampledRDD[InternalRow, InternalRow](
child.execute(),
new PoissonSampler[InternalRow](upperBound - lowerBound, useGapSamplingIfPossible = false),
preservesPartitioning = true,
seed)
} else {
child.execute().map(_.copy()).randomSampleWithRange(lowerBound, upperBound, seed)
child.execute().randomSampleWithRange(lowerBound, upperBound, seed)
}
}
}
......
......@@ -30,6 +30,41 @@ class DataFrameStatSuite extends QueryTest {
private def toLetter(i: Int): String = (i + 97).toChar.toString
test("sample with replacement") {
val n = 100
val data = sqlCtx.sparkContext.parallelize(1 to n, 2).toDF("id")
checkAnswer(
data.sample(withReplacement = true, 0.05, seed = 13),
Seq(5, 10, 52, 73).map(Row(_))
)
}
test("sample without replacement") {
val n = 100
val data = sqlCtx.sparkContext.parallelize(1 to n, 2).toDF("id")
checkAnswer(
data.sample(withReplacement = false, 0.05, seed = 13),
Seq(16, 23, 88, 100).map(Row(_))
)
}
test("randomSplit") {
val n = 600
val data = sqlCtx.sparkContext.parallelize(1 to n, 2).toDF("id")
for (seed <- 1 to 5) {
val splits = data.randomSplit(Array[Double](1, 2, 3), seed)
assert(splits.length == 3, "wrong number of splits")
assert(splits.reduce((a, b) => a.unionAll(b)).sort("id").collect().toList ==
data.collect().toList, "incomplete or wrong split")
val s = splits.map(_.count())
assert(math.abs(s(0) - 100) < 50) // std = 9.13
assert(math.abs(s(1) - 200) < 50) // std = 11.55
assert(math.abs(s(2) - 300) < 50) // std = 12.25
}
}
test("pearson correlation") {
val df = Seq.tabulate(10)(i => (i, 2 * i, i * -1.0)).toDF("a", "b", "c")
val corr1 = df.stat.corr("a", "b", "pearson")
......
......@@ -415,23 +415,6 @@ class DataFrameSuite extends QueryTest with SQLTestUtils {
assert(df.schema.map(_.name) === Seq("key", "valueRenamed", "newCol"))
}
test("randomSplit") {
val n = 600
val data = sqlContext.sparkContext.parallelize(1 to n, 2).toDF("id")
for (seed <- 1 to 5) {
val splits = data.randomSplit(Array[Double](1, 2, 3), seed)
assert(splits.length == 3, "wrong number of splits")
assert(splits.reduce((a, b) => a.unionAll(b)).sort("id").collect().toList ==
data.collect().toList, "incomplete or wrong split")
val s = splits.map(_.count())
assert(math.abs(s(0) - 100) < 50) // std = 9.13
assert(math.abs(s(1) - 200) < 50) // std = 11.55
assert(math.abs(s(2) - 300) < 50) // std = 12.25
}
}
test("describe") {
val describeTestData = Seq(
("Bob", 16, 176),
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment