Commit 149b3ee2 authored by Burak Yavuz's avatar Burak Yavuz Committed by Reynold Xin

[SPARK-7242][SQL][MLLIB] Frequent items for DataFrames

Finds frequent items, possibly with false positives, using the algorithm described in `http://www.cs.umd.edu/~samir/498/karp.pdf`.
Public API:
```
df.stat.freqItems(cols: Array[String], support: Double = 0.001): DataFrame
```

The output is a local DataFrame whose column names are the input column names with `_freqItems` appended to them. This is a single-pass algorithm that may return false positives, but no false negatives.

cc mengxr rxin

Let's get the implementation in; I can add the Python API in a follow-up PR.

Author: Burak Yavuz <brkyvz@gmail.com>

Closes #5799 from brkyvz/freq-items and squashes the following commits:

a6ec82c [Burak Yavuz] addressed comments v?
39b1bba [Burak Yavuz] removed toSeq
0915e23 [Burak Yavuz] addressed comments v2.1
3a5c177 [Burak Yavuz] addressed comments v2.0
482e741 [Burak Yavuz] removed old import
38e784d [Burak Yavuz] addressed comments v1.0
8279d4d [Burak Yavuz] added default value for support
3d82168 [Burak Yavuz] made base implementation
parent 1c3e402e
@@ -330,6 +330,17 @@ class DataFrame private[sql](
*/
def na: DataFrameNaFunctions = new DataFrameNaFunctions(this)
/**
* Returns a [[DataFrameStatFunctions]] for working with statistic functions.
* {{{
* // Finding frequent items in the column named 'a'.
* df.stat.freqItems(Seq("a"))
* }}}
*
* @group dfops
*/
def stat: DataFrameStatFunctions = new DataFrameStatFunctions(this)
/**
* Cartesian join with another [[DataFrame]].
*
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql
import org.apache.spark.annotation.Experimental
import org.apache.spark.sql.execution.stat.FrequentItems
/**
* :: Experimental ::
* Statistic functions for [[DataFrame]]s.
*/
@Experimental
final class DataFrameStatFunctions private[sql](df: DataFrame) {
/**
* Finding frequent items for columns, possibly with false positives, using the
* frequent element count algorithm described in
* [[http://dx.doi.org/10.1145/762471.762473, proposed by Karp, Schenker, and Papadimitriou]].
* The `support` should be at least 1e-4.
*
* @param cols the names of the columns to search frequent items in.
* @param support The minimum frequency for an item to be considered `frequent`. Should be at
*                least 1e-4.
* @return A local DataFrame with an Array of frequent items for each column.
*/
def freqItems(cols: Array[String], support: Double): DataFrame = {
FrequentItems.singlePassFreqItems(df, cols, support)
}
/**
* Runs `freqItems` with a default `support` of 1%.
*
* @param cols the names of the columns to search frequent items in.
* @return A local DataFrame with an Array of frequent items for each column.
*/
def freqItems(cols: Array[String]): DataFrame = {
FrequentItems.singlePassFreqItems(df, cols, 0.01)
}
/**
* Python-friendly implementation of `freqItems`.
*/
def freqItems(cols: List[String], support: Double): DataFrame = {
FrequentItems.singlePassFreqItems(df, cols, support)
}
/**
* Python-friendly implementation of `freqItems` with a default `support` of 1%.
*/
def freqItems(cols: List[String]): DataFrame = {
FrequentItems.singlePassFreqItems(df, cols, 0.01)
}
}
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.execution.stat
import scala.collection.mutable.{Map => MutableMap}
import org.apache.spark.Logging
import org.apache.spark.sql.{Column, DataFrame, Row}
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
import org.apache.spark.sql.types.{ArrayType, StructField, StructType}
private[sql] object FrequentItems extends Logging {
/** A helper class wrapping `MutableMap[Any, Long]` for simplicity. */
private class FreqItemCounter(size: Int) extends Serializable {
val baseMap: MutableMap[Any, Long] = MutableMap.empty[Any, Long]
/**
* Increment the count of `key` if it is already tracked. Otherwise, add it if there
* is room, or deduct `count` from all existing items if the map is full.
*/
def add(key: Any, count: Long): this.type = {
if (baseMap.contains(key)) {
baseMap(key) += count
} else {
if (baseMap.size < size) {
baseMap += key -> count
} else {
// TODO: Make this more efficient... A flatMap?
baseMap.retain((k, v) => v > count)
baseMap.transform((k, v) => v - count)
}
}
this
}
/**
* Merge two maps of counts.
* @param other The map containing the counts for another partition
*/
def merge(other: FreqItemCounter): this.type = {
other.baseMap.foreach { case (k, v) =>
add(k, v)
}
this
}
}
/**
* Finding frequent items for columns, possibly with false positives, using the
* frequent element count algorithm described in
* [[http://dx.doi.org/10.1145/762471.762473, proposed by Karp, Schenker, and Papadimitriou]].
* The `support` should be at least 1e-4. For internal use only.
*
* @param df The input DataFrame
* @param cols the names of the columns to search frequent items in
* @param support The minimum frequency for an item to be considered `frequent`. Should be at
*                least 1e-4.
* @return A local DataFrame with an Array of frequent items for each column.
*/
private[sql] def singlePassFreqItems(
df: DataFrame,
cols: Seq[String],
support: Double): DataFrame = {
require(support >= 1e-4, s"support ($support) must be at least 1e-4.")
val numCols = cols.length
// The maximum number of items to keep counts for. With roughly 1 / support counters,
// any item occurring in more than a `support` fraction of rows is retained.
val sizeOfMap = (1 / support).toInt
val countMaps = Seq.tabulate(numCols)(i => new FreqItemCounter(sizeOfMap))
val originalSchema = df.schema
val colInfo = cols.map { name =>
val index = originalSchema.fieldIndex(name)
(name, originalSchema.fields(index).dataType)
}
val freqItems = df.select(cols.map(Column(_)):_*).rdd.aggregate(countMaps)(
seqOp = (counts, row) => {
var i = 0
while (i < numCols) {
val thisMap = counts(i)
val key = row.get(i)
thisMap.add(key, 1L)
i += 1
}
counts
},
combOp = (baseCounts, counts) => {
var i = 0
while (i < numCols) {
baseCounts(i).merge(counts(i))
i += 1
}
baseCounts
}
)
val justItems = freqItems.map(m => m.baseMap.keys.toSeq)
val resultRow = Row(justItems:_*)
// append `_freqItems` to the column names for easy debugging
val outputCols = colInfo.map { v =>
StructField(v._1 + "_freqItems", ArrayType(v._2, false))
}
val schema = StructType(outputCols).toAttributes
new DataFrame(df.sqlContext, LocalRelation(schema, Seq(resultRow)))
}
}
@@ -22,10 +22,7 @@ import com.google.common.primitives.Ints;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.TestData$;
import org.apache.spark.sql.*;
import org.apache.spark.sql.test.TestSQLContext;
import org.apache.spark.sql.test.TestSQLContext$;
import org.apache.spark.sql.types.*;
@@ -178,5 +175,12 @@ public class JavaDataFrameSuite {
Assert.assertEquals(bean.getD().get(i), d.apply(i));
}
}
@Test
public void testFrequentItems() {
DataFrame df = context.table("testData2");
String[] cols = new String[]{"a"};
DataFrame results = df.stat().freqItems(cols, 0.2);
Assert.assertTrue(results.collect()[0].getSeq(0).contains(1));
}
}
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql
import org.scalatest.FunSuite
import org.scalatest.Matchers._
import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.test.TestSQLContext.implicits._
class DataFrameStatSuite extends FunSuite {
val sqlCtx = TestSQLContext
test("Frequent Items") {
def toLetter(i: Int): String = (i + 96).toChar.toString
val rows = Array.tabulate(1000) { i =>
if (i % 3 == 0) (1, toLetter(1), -1.0) else (i, toLetter(i), i * -1.0)
}
val df = sqlCtx.sparkContext.parallelize(rows).toDF("numbers", "letters", "negDoubles")
val results = df.stat.freqItems(Array("numbers", "letters"), 0.1)
val items = results.collect().head
items.getSeq[Int](0) should contain (1)
items.getSeq[String](1) should contain (toLetter(1))
val singleColResults = df.stat.freqItems(Array("negDoubles"), 0.1)
val items2 = singleColResults.collect().head
items2.getSeq[Double](0) should contain (-1.0)
}
}