From 3336c7b148ad543d1f9b64ca2b559ea04930f5be Mon Sep 17 00:00:00 2001 From: Feynman Liang <fliang@databricks.com> Date: Tue, 7 Jul 2015 11:34:30 -0700 Subject: [PATCH] [SPARK-8559] [MLLIB] Support Association Rule Generation Distributed generation of single-consequent association rules from a RDD of frequent itemsets. Tests referenced against `R`'s implementation of A Priori in [arules](http://cran.r-project.org/web/packages/arules/index.html). Author: Feynman Liang <fliang@databricks.com> Closes #7005 from feynmanliang/fp-association-rules-distributed and squashes the following commits: 466ced0 [Feynman Liang] Refactor AR generation impl 73c1cff [Feynman Liang] Make rule attributes public, remove numTransactions from FreqItemset 80f63ff [Feynman Liang] Change default confidence and optimize imports 04cf5b5 [Feynman Liang] Code review with @mengxr, add R to tests 0cc1a6a [Feynman Liang] Java compatibility test f3c14b5 [Feynman Liang] Fix MiMa test 764375e [Feynman Liang] Fix tests 1187307 [Feynman Liang] Almost working tests b20779b [Feynman Liang] Working implementation 5395c4e [Feynman Liang] Fix imports 2d34405 [Feynman Liang] Partial implementation of distributed ar 83ace4b [Feynman Liang] Local rule generation without pruning complete 69c2c87 [Feynman Liang] Working local implementation, now to parallelize../.. 4e1ec9a [Feynman Liang] Pull FreqItemsets out, refactor type param, tests 69ccedc [Feynman Liang] First implementation of association rule generation --- .../spark/mllib/fpm/AssociationRules.scala | 108 ++++++++++++++++++ .../org/apache/spark/mllib/fpm/FPGrowth.scala | 2 +- .../mllib/fpm/JavaAssociationRulesSuite.java | 58 ++++++++++ .../spark/mllib/fpm/JavaFPGrowthSuite.java | 5 +- .../mllib/fpm/AssociationRulesSuite.scala | 89 +++++++++++++++ 5 files changed, 258 insertions(+), 4 deletions(-) create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/fpm/AssociationRules.scala create mode 100644 mllib/src/test/java/org/apache/spark/mllib/fpm/JavaAssociationRulesSuite.java create mode 100644 mllib/src/test/scala/org/apache/spark/mllib/fpm/AssociationRulesSuite.scala diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/AssociationRules.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/AssociationRules.scala new file mode 100644 index 0000000000..4a0f842f33 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/AssociationRules.scala @@ -0,0 +1,108 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.mllib.fpm + +import scala.reflect.ClassTag + +import org.apache.spark.Logging +import org.apache.spark.annotation.Experimental +import org.apache.spark.api.java.JavaRDD +import org.apache.spark.api.java.JavaSparkContext.fakeClassTag +import org.apache.spark.mllib.fpm.AssociationRules.Rule +import org.apache.spark.mllib.fpm.FPGrowth.FreqItemset +import org.apache.spark.rdd.RDD + +/** + * :: Experimental :: + * + * Generates association rules from a [[RDD[FreqItemset[Item]]]. This method only generates + * association rules which have a single item as the consequent. + */ +@Experimental +class AssociationRules private ( + private var minConfidence: Double) extends Logging with Serializable { + + /** + * Constructs a default instance with default parameters {minConfidence = 0.8}. + */ + def this() = this(0.8) + + /** + * Sets the minimal confidence (default: `0.8`). + */ + def setMinConfidence(minConfidence: Double): this.type = { + this.minConfidence = minConfidence + this + } + + /** + * Computes the association rules with confidence above [[minConfidence]]. + * @param freqItemsets frequent itemset model obtained from [[FPGrowth]] + * @return a [[Set[Rule[Item]]] containing the assocation rules. + */ + def run[Item: ClassTag](freqItemsets: RDD[FreqItemset[Item]]): RDD[Rule[Item]] = { + // For candidate rule X => Y, generate (X, (Y, freq(X union Y))) + val candidates = freqItemsets.flatMap { itemset => + val items = itemset.items + items.flatMap { item => + items.partition(_ == item) match { + case (consequent, antecedent) if !antecedent.isEmpty => + Some((antecedent.toSeq, (consequent.toSeq, itemset.freq))) + case _ => None + } + } + } + + // Join to get (X, ((Y, freq(X union Y)), freq(X))), generate rules, and filter by confidence + candidates.join(freqItemsets.map(x => (x.items.toSeq, x.freq))) + .map { case (antecendent, ((consequent, freqUnion), freqAntecedent)) => + new Rule(antecendent.toArray, consequent.toArray, freqUnion, freqAntecedent) + }.filter(_.confidence >= minConfidence) + } + + def run[Item](freqItemsets: JavaRDD[FreqItemset[Item]]): JavaRDD[Rule[Item]] = { + val tag = fakeClassTag[Item] + run(freqItemsets.rdd)(tag) + } +} + +object AssociationRules { + + /** + * :: Experimental :: + * + * An association rule between sets of items. + * @param antecedent hypotheses of the rule + * @param consequent conclusion of the rule + * @tparam Item item type + */ + @Experimental + class Rule[Item] private[mllib] ( + val antecedent: Array[Item], + val consequent: Array[Item], + freqUnion: Double, + freqAntecedent: Double) extends Serializable { + + def confidence: Double = freqUnion.toDouble / freqAntecedent + + require(antecedent.toSet.intersect(consequent.toSet).isEmpty, { + val sharedItems = antecedent.toSet.intersect(consequent.toSet) + s"A valid association rule must have disjoint antecedent and " + + s"consequent but ${sharedItems} is present in both." + }) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala index efa8459d3c..0da59e812d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala @@ -28,7 +28,7 @@ import org.apache.spark.{HashPartitioner, Logging, Partitioner, SparkException} import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.JavaRDD import org.apache.spark.api.java.JavaSparkContext.fakeClassTag -import org.apache.spark.mllib.fpm.FPGrowth.FreqItemset +import org.apache.spark.mllib.fpm.FPGrowth._ import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel diff --git a/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaAssociationRulesSuite.java b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaAssociationRulesSuite.java new file mode 100644 index 0000000000..b3815ae603 --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaAssociationRulesSuite.java @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.mllib.fpm; + +import java.io.Serializable; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import com.google.common.collect.Lists; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.mllib.fpm.FPGrowth.FreqItemset; + + +public class JavaAssociationRulesSuite implements Serializable { + private transient JavaSparkContext sc; + + @Before + public void setUp() { + sc = new JavaSparkContext("local", "JavaFPGrowth"); + } + + @After + public void tearDown() { + sc.stop(); + sc = null; + } + + @Test + public void runAssociationRules() { + + @SuppressWarnings("unchecked") + JavaRDD<FPGrowth.FreqItemset<String>> freqItemsets = sc.parallelize(Lists.newArrayList( + new FreqItemset<String>(new String[] {"a"}, 15L), + new FreqItemset<String>(new String[] {"b"}, 35L), + new FreqItemset<String>(new String[] {"a", "b"}, 18L) + )); + + JavaRDD<AssociationRules.Rule<String>> results = (new AssociationRules()).run(freqItemsets); + } +} + diff --git a/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java index bd0edf2b9e..9ce2c52dca 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java @@ -29,7 +29,6 @@ import static org.junit.Assert.*; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.mllib.fpm.FPGrowth.FreqItemset; public class JavaFPGrowthSuite implements Serializable { private transient JavaSparkContext sc; @@ -62,10 +61,10 @@ public class JavaFPGrowthSuite implements Serializable { .setNumPartitions(2) .run(rdd); - List<FreqItemset<String>> freqItemsets = model.freqItemsets().toJavaRDD().collect(); + List<FPGrowth.FreqItemset<String>> freqItemsets = model.freqItemsets().toJavaRDD().collect(); assertEquals(18, freqItemsets.size()); - for (FreqItemset<String> itemset: freqItemsets) { + for (FPGrowth.FreqItemset<String> itemset: freqItemsets) { // Test return types. List<String> items = itemset.javaItems(); long freq = itemset.freq(); diff --git a/mllib/src/test/scala/org/apache/spark/mllib/fpm/AssociationRulesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/fpm/AssociationRulesSuite.scala new file mode 100644 index 0000000000..77a2773c36 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/fpm/AssociationRulesSuite.scala @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.mllib.fpm + +import org.apache.spark.SparkFunSuite +import org.apache.spark.mllib.util.MLlibTestSparkContext + +class AssociationRulesSuite extends SparkFunSuite with MLlibTestSparkContext { + + test("association rules using String type") { + val freqItemsets = sc.parallelize(Seq( + (Set("s"), 3L), (Set("z"), 5L), (Set("x"), 4L), (Set("t"), 3L), (Set("y"), 3L), + (Set("r"), 3L), + (Set("x", "z"), 3L), (Set("t", "y"), 3L), (Set("t", "x"), 3L), (Set("s", "x"), 3L), + (Set("y", "x"), 3L), (Set("y", "z"), 3L), (Set("t", "z"), 3L), + (Set("y", "x", "z"), 3L), (Set("t", "x", "z"), 3L), (Set("t", "y", "z"), 3L), + (Set("t", "y", "x"), 3L), + (Set("t", "y", "x", "z"), 3L) + ).map { + case (items, freq) => new FPGrowth.FreqItemset(items.toArray, freq) + }) + + val ar = new AssociationRules() + + val results1 = ar + .setMinConfidence(0.9) + .run(freqItemsets) + .collect() + + /* Verify results using the `R` code: + transactions = as(sapply( + list("r z h k p", + "z y x w v u t s", + "s x o n r", + "x z y m t s q e", + "z", + "x z y r q t p"), + FUN=function(x) strsplit(x," ",fixed=TRUE)), + "transactions") + ars = apriori(transactions, + parameter = list(support = 0.0, confidence = 0.5, target="rules", minlen=2)) + arsDF = as(ars, "data.frame") + arsDF$support = arsDF$support * length(transactions) + names(arsDF)[names(arsDF) == "support"] = "freq" + > nrow(arsDF) + [1] 23 + > sum(arsDF$confidence == 1) + [1] 23 + */ + assert(results1.size === 23) + assert(results1.count(rule => math.abs(rule.confidence - 1.0D) < 1e-6) == 23) + + val results2 = ar + .setMinConfidence(0) + .run(freqItemsets) + .collect() + + /* Verify results using the `R` code: + ars = apriori(transactions, + parameter = list(support = 0.5, confidence = 0.5, target="rules", minlen=2)) + arsDF = as(ars, "data.frame") + arsDF$support = arsDF$support * length(transactions) + names(arsDF)[names(arsDF) == "support"] = "freq" + nrow(arsDF) + sum(arsDF$confidence == 1) + > nrow(arsDF) + [1] 30 + > sum(arsDF$confidence == 1) + [1] 23 + */ + assert(results2.size === 30) + assert(results2.count(rule => math.abs(rule.confidence - 1.0D) < 1e-6) == 23) + } +} + -- GitLab