diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/AssociationRules.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/AssociationRules.scala new file mode 100644 index 0000000000000000000000000000000000000000..4a0f842f3338d2646579c9e4c005f04219afce51 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/AssociationRules.scala @@ -0,0 +1,108 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.mllib.fpm + +import scala.reflect.ClassTag + +import org.apache.spark.Logging +import org.apache.spark.annotation.Experimental +import org.apache.spark.api.java.JavaRDD +import org.apache.spark.api.java.JavaSparkContext.fakeClassTag +import org.apache.spark.mllib.fpm.AssociationRules.Rule +import org.apache.spark.mllib.fpm.FPGrowth.FreqItemset +import org.apache.spark.rdd.RDD + +/** + * :: Experimental :: + * + * Generates association rules from a [[RDD[FreqItemset[Item]]]. This method only generates + * association rules which have a single item as the consequent. + */ +@Experimental +class AssociationRules private ( + private var minConfidence: Double) extends Logging with Serializable { + + /** + * Constructs a default instance with default parameters {minConfidence = 0.8}. + */ + def this() = this(0.8) + + /** + * Sets the minimal confidence (default: `0.8`). + */ + def setMinConfidence(minConfidence: Double): this.type = { + this.minConfidence = minConfidence + this + } + + /** + * Computes the association rules with confidence above [[minConfidence]]. + * @param freqItemsets frequent itemset model obtained from [[FPGrowth]] + * @return a [[Set[Rule[Item]]] containing the assocation rules. + */ + def run[Item: ClassTag](freqItemsets: RDD[FreqItemset[Item]]): RDD[Rule[Item]] = { + // For candidate rule X => Y, generate (X, (Y, freq(X union Y))) + val candidates = freqItemsets.flatMap { itemset => + val items = itemset.items + items.flatMap { item => + items.partition(_ == item) match { + case (consequent, antecedent) if !antecedent.isEmpty => + Some((antecedent.toSeq, (consequent.toSeq, itemset.freq))) + case _ => None + } + } + } + + // Join to get (X, ((Y, freq(X union Y)), freq(X))), generate rules, and filter by confidence + candidates.join(freqItemsets.map(x => (x.items.toSeq, x.freq))) + .map { case (antecendent, ((consequent, freqUnion), freqAntecedent)) => + new Rule(antecendent.toArray, consequent.toArray, freqUnion, freqAntecedent) + }.filter(_.confidence >= minConfidence) + } + + def run[Item](freqItemsets: JavaRDD[FreqItemset[Item]]): JavaRDD[Rule[Item]] = { + val tag = fakeClassTag[Item] + run(freqItemsets.rdd)(tag) + } +} + +object AssociationRules { + + /** + * :: Experimental :: + * + * An association rule between sets of items. + * @param antecedent hypotheses of the rule + * @param consequent conclusion of the rule + * @tparam Item item type + */ + @Experimental + class Rule[Item] private[mllib] ( + val antecedent: Array[Item], + val consequent: Array[Item], + freqUnion: Double, + freqAntecedent: Double) extends Serializable { + + def confidence: Double = freqUnion.toDouble / freqAntecedent + + require(antecedent.toSet.intersect(consequent.toSet).isEmpty, { + val sharedItems = antecedent.toSet.intersect(consequent.toSet) + s"A valid association rule must have disjoint antecedent and " + + s"consequent but ${sharedItems} is present in both." + }) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala index efa8459d3cdba2ac6452d80c1b3c8272615bbd3d..0da59e812d5f9260051604a7103f6b4d271e3cec 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala @@ -28,7 +28,7 @@ import org.apache.spark.{HashPartitioner, Logging, Partitioner, SparkException} import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.JavaRDD import org.apache.spark.api.java.JavaSparkContext.fakeClassTag -import org.apache.spark.mllib.fpm.FPGrowth.FreqItemset +import org.apache.spark.mllib.fpm.FPGrowth._ import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel diff --git a/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaAssociationRulesSuite.java b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaAssociationRulesSuite.java new file mode 100644 index 0000000000000000000000000000000000000000..b3815ae6039c0d137695530ce2112686c72f9b05 --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaAssociationRulesSuite.java @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.mllib.fpm; + +import java.io.Serializable; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import com.google.common.collect.Lists; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.mllib.fpm.FPGrowth.FreqItemset; + + +public class JavaAssociationRulesSuite implements Serializable { + private transient JavaSparkContext sc; + + @Before + public void setUp() { + sc = new JavaSparkContext("local", "JavaFPGrowth"); + } + + @After + public void tearDown() { + sc.stop(); + sc = null; + } + + @Test + public void runAssociationRules() { + + @SuppressWarnings("unchecked") + JavaRDD<FPGrowth.FreqItemset<String>> freqItemsets = sc.parallelize(Lists.newArrayList( + new FreqItemset<String>(new String[] {"a"}, 15L), + new FreqItemset<String>(new String[] {"b"}, 35L), + new FreqItemset<String>(new String[] {"a", "b"}, 18L) + )); + + JavaRDD<AssociationRules.Rule<String>> results = (new AssociationRules()).run(freqItemsets); + } +} + diff --git a/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java index bd0edf2b9ea62f7ce236493a84b2925f34951ced..9ce2c52dca8b6c8986656e72c74c518887a16d13 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java @@ -29,7 +29,6 @@ import static org.junit.Assert.*; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.mllib.fpm.FPGrowth.FreqItemset; public class JavaFPGrowthSuite implements Serializable { private transient JavaSparkContext sc; @@ -62,10 +61,10 @@ public class JavaFPGrowthSuite implements Serializable { .setNumPartitions(2) .run(rdd); - List<FreqItemset<String>> freqItemsets = model.freqItemsets().toJavaRDD().collect(); + List<FPGrowth.FreqItemset<String>> freqItemsets = model.freqItemsets().toJavaRDD().collect(); assertEquals(18, freqItemsets.size()); - for (FreqItemset<String> itemset: freqItemsets) { + for (FPGrowth.FreqItemset<String> itemset: freqItemsets) { // Test return types. List<String> items = itemset.javaItems(); long freq = itemset.freq(); diff --git a/mllib/src/test/scala/org/apache/spark/mllib/fpm/AssociationRulesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/fpm/AssociationRulesSuite.scala new file mode 100644 index 0000000000000000000000000000000000000000..77a2773c36f5661336b2f03085eb0f2cfc908268 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/fpm/AssociationRulesSuite.scala @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.mllib.fpm + +import org.apache.spark.SparkFunSuite +import org.apache.spark.mllib.util.MLlibTestSparkContext + +class AssociationRulesSuite extends SparkFunSuite with MLlibTestSparkContext { + + test("association rules using String type") { + val freqItemsets = sc.parallelize(Seq( + (Set("s"), 3L), (Set("z"), 5L), (Set("x"), 4L), (Set("t"), 3L), (Set("y"), 3L), + (Set("r"), 3L), + (Set("x", "z"), 3L), (Set("t", "y"), 3L), (Set("t", "x"), 3L), (Set("s", "x"), 3L), + (Set("y", "x"), 3L), (Set("y", "z"), 3L), (Set("t", "z"), 3L), + (Set("y", "x", "z"), 3L), (Set("t", "x", "z"), 3L), (Set("t", "y", "z"), 3L), + (Set("t", "y", "x"), 3L), + (Set("t", "y", "x", "z"), 3L) + ).map { + case (items, freq) => new FPGrowth.FreqItemset(items.toArray, freq) + }) + + val ar = new AssociationRules() + + val results1 = ar + .setMinConfidence(0.9) + .run(freqItemsets) + .collect() + + /* Verify results using the `R` code: + transactions = as(sapply( + list("r z h k p", + "z y x w v u t s", + "s x o n r", + "x z y m t s q e", + "z", + "x z y r q t p"), + FUN=function(x) strsplit(x," ",fixed=TRUE)), + "transactions") + ars = apriori(transactions, + parameter = list(support = 0.0, confidence = 0.5, target="rules", minlen=2)) + arsDF = as(ars, "data.frame") + arsDF$support = arsDF$support * length(transactions) + names(arsDF)[names(arsDF) == "support"] = "freq" + > nrow(arsDF) + [1] 23 + > sum(arsDF$confidence == 1) + [1] 23 + */ + assert(results1.size === 23) + assert(results1.count(rule => math.abs(rule.confidence - 1.0D) < 1e-6) == 23) + + val results2 = ar + .setMinConfidence(0) + .run(freqItemsets) + .collect() + + /* Verify results using the `R` code: + ars = apriori(transactions, + parameter = list(support = 0.5, confidence = 0.5, target="rules", minlen=2)) + arsDF = as(ars, "data.frame") + arsDF$support = arsDF$support * length(transactions) + names(arsDF)[names(arsDF) == "support"] = "freq" + nrow(arsDF) + sum(arsDF$confidence == 1) + > nrow(arsDF) + [1] 30 + > sum(arsDF$confidence == 1) + [1] 23 + */ + assert(results2.size === 30) + assert(results2.count(rule => math.abs(rule.confidence - 1.0D) < 1e-6) == 23) + } +} +