Skip to content
Snippets Groups Projects
Commit 3336c7b1 authored by Feynman Liang's avatar Feynman Liang Committed by Xiangrui Meng
Browse files

[SPARK-8559] [MLLIB] Support Association Rule Generation

Distributed generation of single-consequent association rules from an RDD of frequent itemsets. Test results are checked against the Apriori implementation in `R`'s [arules](http://cran.r-project.org/web/packages/arules/index.html) package.

Author: Feynman Liang <fliang@databricks.com>

Closes #7005 from feynmanliang/fp-association-rules-distributed and squashes the following commits:

466ced0 [Feynman Liang] Refactor AR generation impl
73c1cff [Feynman Liang] Make rule attributes public, remove numTransactions from FreqItemset
80f63ff [Feynman Liang] Change default confidence and optimize imports
04cf5b5 [Feynman Liang] Code review with @mengxr, add R to tests
0cc1a6a [Feynman Liang] Java compatibility test
f3c14b5 [Feynman Liang] Fix MiMa test
764375e [Feynman Liang] Fix tests
1187307 [Feynman Liang] Almost working tests
b20779b [Feynman Liang] Working implementation
5395c4e [Feynman Liang] Fix imports
2d34405 [Feynman Liang] Partial implementation of distributed ar
83ace4b [Feynman Liang] Local rule generation without pruning complete
69c2c87 [Feynman Liang] Working local implementation, now to parallelize../..
4e1ec9a [Feynman Liang] Pull FreqItemsets out, refactor type param, tests
69ccedc [Feynman Liang] First implementation of association rule generation
parent 70beb808
No related branches found
No related tags found
No related merge requests found
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.mllib.fpm
import scala.reflect.ClassTag
import org.apache.spark.Logging
import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.api.java.JavaSparkContext.fakeClassTag
import org.apache.spark.mllib.fpm.AssociationRules.Rule
import org.apache.spark.mllib.fpm.FPGrowth.FreqItemset
import org.apache.spark.rdd.RDD
/**
 * :: Experimental ::
 *
 * Generates association rules from an `RDD[FreqItemset[Item]]`. This method only generates
 * association rules which have a single item as the consequent.
 */
@Experimental
class AssociationRules private (
    private var minConfidence: Double) extends Logging with Serializable {

  /**
   * Constructs a default instance with default parameters {minConfidence = 0.8}.
   */
  def this() = this(0.8)

  /**
   * Sets the minimal confidence (default: `0.8`).
   *
   * @param minConfidence lower bound (inclusive) on the confidence of the generated rules;
   *                      must lie in the range [0, 1]
   * @return this instance, to allow call chaining
   * @throws IllegalArgumentException if `minConfidence` is outside [0, 1] (including NaN)
   */
  def setMinConfidence(minConfidence: Double): this.type = {
    require(minConfidence >= 0.0 && minConfidence <= 1.0,
      s"Minimal confidence must be in range [0, 1] but got ${minConfidence}")
    this.minConfidence = minConfidence
    this
  }

  /**
   * Computes the association rules with confidence greater than or equal to the configured
   * minimal confidence.
   *
   * @param freqItemsets frequent itemsets, e.g. the ones obtained from [[FPGrowth]]
   * @return an `RDD[Rule[Item]]` containing the association rules
   */
  def run[Item: ClassTag](freqItemsets: RDD[FreqItemset[Item]]): RDD[Rule[Item]] = {
    // Copy the threshold into a local so the filter closure below captures only this value
    // instead of the enclosing (mutable) AssociationRules instance.
    val confidenceThreshold = minConfidence

    // For candidate rule X => Y, generate (X, (Y, freq(X union Y)))
    val candidates = freqItemsets.flatMap { itemset =>
      val items = itemset.items
      items.flatMap { item =>
        // Split the itemset into the single chosen item (consequent) and the rest (antecedent).
        items.partition(_ == item) match {
          case (consequent, antecedent) if antecedent.nonEmpty =>
            Some((antecedent.toSeq, (consequent.toSeq, itemset.freq)))
          // An itemset with nothing left for the antecedent cannot form a rule.
          case _ => None
        }
      }
    }

    // Join to get (X, ((Y, freq(X union Y)), freq(X))), generate rules, and filter by confidence.
    // NOTE: the join key is an ordered Seq, so an antecedent only matches a frequent itemset
    // whose items are stored in the same relative order.
    candidates.join(freqItemsets.map(x => (x.items.toSeq, x.freq)))
      .map { case (antecedent, ((consequent, freqUnion), freqAntecedent)) =>
        new Rule(antecedent.toArray, consequent.toArray, freqUnion, freqAntecedent)
      }.filter(_.confidence >= confidenceThreshold)
  }

  /**
   * Java-friendly version of `run` operating on a [[JavaRDD]].
   *
   * @param freqItemsets frequent itemsets wrapped in a [[JavaRDD]]
   * @return a [[JavaRDD]] containing the association rules
   */
  def run[Item](freqItemsets: JavaRDD[FreqItemset[Item]]): JavaRDD[Rule[Item]] = {
    val tag = fakeClassTag[Item]
    run(freqItemsets.rdd)(tag)
  }
}
object AssociationRules {

  /**
   * :: Experimental ::
   *
   * An association rule `antecedent => consequent` between sets of items. Construction fails
   * when the two sides have any item in common.
   *
   * @param antecedent hypotheses (left-hand side) of the rule
   * @param consequent conclusion (right-hand side) of the rule
   * @param freqUnion numerator of [[confidence]]
   * @param freqAntecedent denominator of [[confidence]]
   * @tparam Item item type
   */
  @Experimental
  class Rule[Item] private[mllib] (
      val antecedent: Array[Item],
      val consequent: Array[Item],
      freqUnion: Double,
      freqAntecedent: Double) extends Serializable {

    // Reject rules whose two sides overlap. The message block is only evaluated on failure.
    require((antecedent.toSet & consequent.toSet).isEmpty, {
      val overlap = antecedent.toSet & consequent.toSet
      s"A valid association rule must have disjoint antecedent and " +
        s"consequent but ${overlap} is present in both."
    })

    /** Ratio of `freqUnion` to `freqAntecedent`. */
    def confidence: Double = freqUnion / freqAntecedent
  }
}
...@@ -28,7 +28,7 @@ import org.apache.spark.{HashPartitioner, Logging, Partitioner, SparkException} ...@@ -28,7 +28,7 @@ import org.apache.spark.{HashPartitioner, Logging, Partitioner, SparkException}
import org.apache.spark.annotation.Experimental import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.JavaRDD import org.apache.spark.api.java.JavaRDD
import org.apache.spark.api.java.JavaSparkContext.fakeClassTag import org.apache.spark.api.java.JavaSparkContext.fakeClassTag
import org.apache.spark.mllib.fpm.FPGrowth.FreqItemset import org.apache.spark.mllib.fpm.FPGrowth._
import org.apache.spark.rdd.RDD import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel import org.apache.spark.storage.StorageLevel
......
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.mllib.fpm;
import java.io.Serializable;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import com.google.common.collect.Lists;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.fpm.FPGrowth.FreqItemset;
/**
 * Java compatibility test for {@link AssociationRules}: checks that the Java-facing
 * {@code run} overload can be called with a {@link JavaRDD} and that the resulting
 * pipeline executes.
 */
public class JavaAssociationRulesSuite implements Serializable {
  private transient JavaSparkContext sc;

  @Before
  public void setUp() {
    sc = new JavaSparkContext("local", "JavaAssociationRules");
  }

  @After
  public void tearDown() {
    sc.stop();
    sc = null;
  }

  @Test
  public void runAssociationRules() {

    // Small hand-built fixture of frequent itemsets with their frequencies.
    @SuppressWarnings("unchecked")
    JavaRDD<FPGrowth.FreqItemset<String>> freqItemsets = sc.parallelize(Lists.newArrayList(
      new FreqItemset<String>(new String[] {"a"}, 15L),
      new FreqItemset<String>(new String[] {"b"}, 35L),
      new FreqItemset<String>(new String[] {"a", "b"}, 18L)
    ));

    JavaRDD<AssociationRules.Rule<String>> results = (new AssociationRules()).run(freqItemsets);

    // RDD transformations are lazy; run an action so the rule generation is actually executed
    // instead of only being type-checked.
    results.collect();
  }
}
...@@ -29,7 +29,6 @@ import static org.junit.Assert.*; ...@@ -29,7 +29,6 @@ import static org.junit.Assert.*;
import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.fpm.FPGrowth.FreqItemset;
public class JavaFPGrowthSuite implements Serializable { public class JavaFPGrowthSuite implements Serializable {
private transient JavaSparkContext sc; private transient JavaSparkContext sc;
...@@ -62,10 +61,10 @@ public class JavaFPGrowthSuite implements Serializable { ...@@ -62,10 +61,10 @@ public class JavaFPGrowthSuite implements Serializable {
.setNumPartitions(2) .setNumPartitions(2)
.run(rdd); .run(rdd);
List<FreqItemset<String>> freqItemsets = model.freqItemsets().toJavaRDD().collect(); List<FPGrowth.FreqItemset<String>> freqItemsets = model.freqItemsets().toJavaRDD().collect();
assertEquals(18, freqItemsets.size()); assertEquals(18, freqItemsets.size());
for (FreqItemset<String> itemset: freqItemsets) { for (FPGrowth.FreqItemset<String> itemset: freqItemsets) {
// Test return types. // Test return types.
List<String> items = itemset.javaItems(); List<String> items = itemset.javaItems();
long freq = itemset.freq(); long freq = itemset.freq();
......
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.mllib.fpm
import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.util.MLlibTestSparkContext
class AssociationRulesSuite extends SparkFunSuite with MLlibTestSparkContext {

  test("association rules using String type") {
    // Hand-built frequent itemsets, each paired with its frequency.
    val itemsetsWithFreq = Seq(
      (Set("s"), 3L), (Set("z"), 5L), (Set("x"), 4L), (Set("t"), 3L), (Set("y"), 3L),
      (Set("r"), 3L),
      (Set("x", "z"), 3L), (Set("t", "y"), 3L), (Set("t", "x"), 3L), (Set("s", "x"), 3L),
      (Set("y", "x"), 3L), (Set("y", "z"), 3L), (Set("t", "z"), 3L),
      (Set("y", "x", "z"), 3L), (Set("t", "x", "z"), 3L), (Set("t", "y", "z"), 3L),
      (Set("t", "y", "x"), 3L),
      (Set("t", "y", "x", "z"), 3L)
    )
    val freqItemsets = sc.parallelize(itemsetsWithFreq).map { case (itemset, count) =>
      new FPGrowth.FreqItemset(itemset.toArray, count)
    }

    // Counts the rules whose confidence equals 1 within a 1e-6 tolerance.
    def countConfidenceOne[T](rules: Array[AssociationRules.Rule[T]]): Int =
      rules.count(rule => math.abs(rule.confidence - 1.0D) < 1e-6)

    // The same generator instance is reused; its threshold is changed between the two runs.
    val ar = new AssociationRules()

    ar.setMinConfidence(0.9)
    val results1 = ar.run(freqItemsets).collect()

    /* Verify results using the `R` code:
       transactions = as(sapply(
         list("r z h k p",
              "z y x w v u t s",
              "s x o n r",
              "x z y m t s q e",
              "z",
              "x z y r q t p"),
         FUN=function(x) strsplit(x," ",fixed=TRUE)),
         "transactions")
       ars = apriori(transactions,
                     parameter = list(support = 0.0, confidence = 0.5, target="rules", minlen=2))
       arsDF = as(ars, "data.frame")
       arsDF$support = arsDF$support * length(transactions)
       names(arsDF)[names(arsDF) == "support"] = "freq"
       > nrow(arsDF)
       [1] 23
       > sum(arsDF$confidence == 1)
       [1] 23
     */
    assert(results1.size === 23)
    assert(countConfidenceOne(results1) == 23)

    ar.setMinConfidence(0.0)
    val results2 = ar.run(freqItemsets).collect()

    /* Verify results using the `R` code:
       ars = apriori(transactions,
                     parameter = list(support = 0.5, confidence = 0.5, target="rules", minlen=2))
       arsDF = as(ars, "data.frame")
       arsDF$support = arsDF$support * length(transactions)
       names(arsDF)[names(arsDF) == "support"] = "freq"
       nrow(arsDF)
       sum(arsDF$confidence == 1)
       > nrow(arsDF)
       [1] 30
       > sum(arsDF$confidence == 1)
       [1] 23
     */
    assert(results2.size === 30)
    assert(countConfidenceOne(results2) == 23)
  }
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment