Commit 98e69467 authored by Burak Yavuz, committed by Xiangrui Meng


[SPARK-9615] [SPARK-9616] [SQL] [MLLIB] Bugs related to FrequentItems when merging and with Tungsten

In short:
1- FrequentItems should not use the InternalRow representation, because the keys in the map get corrupted: when the elements are strings, every key in the Map ends up corresponding to the very last element observed in the partition.
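
For illustration, here is a minimal plain-Scala sketch (not Spark internals; `ReusedRow` is a hypothetical stand-in for a row buffer that a scan reuses for every element) of how storing a reused mutable object as a map key produces exactly this symptom:

```scala
import scala.collection.mutable

// Hypothetical stand-in for a row buffer that the scan reuses for every element.
class ReusedRow {
  var value: String = _
  override def hashCode: Int = if (value == null) 0 else value.hashCode
  override def equals(o: Any): Boolean = o match {
    case r: ReusedRow => r.value == value
    case _            => false
  }
}

object ReusedKeyDemo extends App {
  val counts = mutable.Map.empty[Any, Long]
  val buffer = new ReusedRow
  for (s <- Seq("a", "b", "c")) {
    buffer.value = s
    // The buffer object itself is stored as the key; its contents keep changing afterwards.
    counts(buffer) = counts.getOrElse(buffer, 0L) + 1L
  }
  // Every key in the map now reports the last element observed ("c"), not "a"/"b"/"c".
  assert(counts.keysIterator.forall(_.asInstanceOf[ReusedRow].value == "c"))
}
```

Reading the data through `df.rdd` instead of `queryExecution.toRdd`, as the patch below does, yields external `Row` values rather than reused internal rows, which avoids this.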

2- Merging the counter maps of two partitions had a bug: when an incoming count exceeded every count already in a full map, the old code evicted all existing entries and never added the new key, producing an empty map.

**Existing behavior with map size 3:**
Partition A -> Map(1 -> 3, 2 -> 3, 3 -> 4)
Partition B -> Map(4 -> 25)
Result -> Map()

**Correct behavior:**
Partition A -> Map(1 -> 3, 2 -> 3, 3 -> 4)
Partition B -> Map(4 -> 25)
Result -> Map(3 -> 1, 4 -> 22)
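
For reference, a self-contained sketch of the patched counter that reproduces the example above; the `contains` branch and the `merge` helper are reconstructed from the class's doc comment and the aggregate's combine step, since they fall outside the diff hunks shown below:

```scala
import scala.collection.mutable.{Map => MutableMap}

// Standalone sketch of the patched add/merge logic (Scala 2.12-era mutable Map API).
class FreqItemCounter(size: Int) extends Serializable {
  val baseMap: MutableMap[Any, Long] = MutableMap.empty[Any, Long]

  def add(key: Any, count: Long): this.type = {
    if (baseMap.contains(key)) {
      baseMap(key) += count
    } else if (baseMap.size < size) {
      baseMap += key -> count
    } else {
      val minCount = baseMap.values.min
      if (count - minCount >= 0) {
        baseMap += key -> count                   // something will get kicked out, so we can add this
        baseMap.retain((_, v) => v > minCount)    // evict entries at or below the old minimum
        baseMap.transform((_, v) => v - minCount)
      } else {
        baseMap.transform((_, v) => v - count)
      }
    }
    this
  }

  // Reconstructed merge: replay the other partition's counts through add.
  def merge(other: FreqItemCounter): this.type = {
    other.baseMap.foreach { case (k, v) => add(k, v) }
    this
  }
}

object MergeDemo extends App {
  val a = new FreqItemCounter(3)
  Seq(1 -> 3L, 2 -> 3L, 3 -> 4L).foreach { case (k, v) => a.add(k, v) }
  val b = new FreqItemCounter(3)
  b.add(4, 25L)
  // Contents: 3 -> 1, 4 -> 22 (iteration order may vary), matching the correct behavior above.
  println(a.merge(b).baseMap)
}
```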

cc mengxr rxin JoshRosen

Author: Burak Yavuz <brkyvz@gmail.com>

Closes #7945 from brkyvz/freq-fix and squashes the following commits:

07fa001 [Burak Yavuz] address 2
1dc61a8 [Burak Yavuz] address 1
506753e [Burak Yavuz] fixed and added reg test
47bfd50 [Burak Yavuz] pushing
parent 076ec056
FrequentItems.scala
@@ -20,17 +20,15 @@ package org.apache.spark.sql.execution.stat
 import scala.collection.mutable.{Map => MutableMap}
 import org.apache.spark.Logging
-import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
 import org.apache.spark.sql.types._
-import org.apache.spark.sql.{Column, DataFrame}
+import org.apache.spark.sql.{Row, Column, DataFrame}
 private[sql] object FrequentItems extends Logging {
   /** A helper class wrapping `MutableMap[Any, Long]` for simplicity. */
   private class FreqItemCounter(size: Int) extends Serializable {
     val baseMap: MutableMap[Any, Long] = MutableMap.empty[Any, Long]
     /**
      * Add a new example to the counts if it exists, otherwise deduct the count
      * from existing items.
@@ -42,9 +40,15 @@ private[sql] object FrequentItems extends Logging {
       if (baseMap.size < size) {
         baseMap += key -> count
       } else {
-        // TODO: Make this more efficient... A flatMap?
-        baseMap.retain((k, v) => v > count)
-        baseMap.transform((k, v) => v - count)
+        val minCount = baseMap.values.min
+        val remainder = count - minCount
+        if (remainder >= 0) {
+          baseMap += key -> count // something will get kicked out, so we can add this
+          baseMap.retain((k, v) => v > minCount)
+          baseMap.transform((k, v) => v - minCount)
+        } else {
+          baseMap.transform((k, v) => v - count)
+        }
       }
     }
     this
@@ -90,12 +94,12 @@ private[sql] object FrequentItems extends Logging {
       (name, originalSchema.fields(index).dataType)
     }.toArray
-    val freqItems = df.select(cols.map(Column(_)) : _*).queryExecution.toRdd.aggregate(countMaps)(
+    val freqItems = df.select(cols.map(Column(_)) : _*).rdd.aggregate(countMaps)(
       seqOp = (counts, row) => {
         var i = 0
         while (i < numCols) {
           val thisMap = counts(i)
-          val key = row.get(i, colInfo(i)._2)
+          val key = row.get(i)
           thisMap.add(key, 1L)
           i += 1
         }
@@ -110,13 +114,13 @@ private[sql] object FrequentItems extends Logging {
         baseCounts
       }
     )
-    val justItems = freqItems.map(m => m.baseMap.keys.toArray).map(new GenericArrayData(_))
-    val resultRow = InternalRow(justItems : _*)
+    val justItems = freqItems.map(m => m.baseMap.keys.toArray)
+    val resultRow = Row(justItems : _*)
     // append frequent Items to the column name for easy debugging
     val outputCols = colInfo.map { v =>
       StructField(v._1 + "_freqItems", ArrayType(v._2, false))
     }
     val schema = StructType(outputCols).toAttributes
-    new DataFrame(df.sqlContext, LocalRelation(schema, Seq(resultRow)))
+    new DataFrame(df.sqlContext, LocalRelation.fromExternalRows(schema, Seq(resultRow)))
   }
 }
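
For context, the relation built here is a single-row DataFrame with one `<column>_freqItems` array column per input column. A hedged usage sketch (the sample data and column names are illustrative, assuming a Spark 1.5-era `sqlContext`):

```scala
// Assumes an existing SQLContext and its implicits in scope.
import sqlContext.implicits._

val df = sqlContext.sparkContext
  .parallelize(Seq(("apple", 1), ("apple", 2), ("banana", 1), ("apple", 3)))
  .toDF("fruit", "n")

// Candidate items that may appear in more than 40% of rows, per column.
val freq = df.stat.freqItems(Array("fruit", "n"), 0.4)

// One result row; each cell is an array, readable with getSeq as in the test suite below.
val row = freq.collect().head
val fruitItems = row.getSeq[String](0)   // column "fruit_freqItems"
assert(fruitItems.contains("apple"))     // "apple" appears in 75% of rows, so it is expected here
```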
DataFrameStatSuite.scala
@@ -123,12 +123,30 @@ class DataFrameStatSuite extends QueryTest {
     val results = df.stat.freqItems(Array("numbers", "letters"), 0.1)
     val items = results.collect().head
-    items.getSeq[Int](0) should contain (1)
-    items.getSeq[String](1) should contain (toLetter(1))
+    assert(items.getSeq[Int](0).contains(1))
+    assert(items.getSeq[String](1).contains(toLetter(1)))
     val singleColResults = df.stat.freqItems(Array("negDoubles"), 0.1)
     val items2 = singleColResults.collect().head
-    items2.getSeq[Double](0) should contain (-1.0)
+    assert(items2.getSeq[Double](0).contains(-1.0))
   }
+  test("Frequent Items 2") {
+    val rows = sqlCtx.sparkContext.parallelize(Seq.empty[Int], 4)
+    // this is a regression test, where when merging partitions, we omitted values with higher
+    // counts than those that existed in the map when the map was full. This test should also fail
+    // if anything like SPARK-9614 is observed once again
+    val df = rows.mapPartitionsWithIndex { (idx, iter) =>
+      if (idx == 3) { // must come from one of the later merges, therefore higher partition index
+        Iterator("3", "3", "3", "3", "3")
+      } else {
+        Iterator("0", "1", "2", "3", "4")
+      }
+    }.toDF("a")
+    val results = df.stat.freqItems(Array("a"), 0.25)
+    val items = results.collect().head.getSeq[String](0)
+    assert(items.contains("3"))
+    assert(items.length === 1)
+  }
   test("sampleBy") {