Commit 7f6be1f2 authored by zhangjiajin, committed by Xiangrui Meng

[SPARK-6487] [MLLIB] Add sequential pattern mining algorithm PrefixSpan to Spark MLlib

Add parallel PrefixSpan algorithm and test file.
Support non-temporal sequences.

Author: zhangjiajin <zhangjiajin@huawei.com>
Author: zhang jiajin <zhangjiajin@huawei.com>

Closes #7258 from zhangjiajin/master and squashes the following commits:

ca9c4c8 [zhangjiajin] Modified the code according to the review comments.
574e56c [zhangjiajin] Add new object LocalPrefixSpan, and do some optimization.
ba5df34 [zhangjiajin] Fix a Scala style error.
4c60fb3 [zhangjiajin] Fix some Scala style errors.
1dd33ad [zhangjiajin] Modified the code according to the review comments.
89bc368 [zhangjiajin] Fixed a Scala style error.
a2eb14c [zhang jiajin] Delete PrefixspanSuite.scala
951fd42 [zhang jiajin] Delete Prefixspan.scala
575995f [zhangjiajin] Modified the code according to the review comments.
91fd7e6 [zhangjiajin] Add new algorithm PrefixSpan and test file.
parent 9c507577
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.mllib.fpm

import org.apache.spark.Logging
import org.apache.spark.annotation.Experimental

/**
 * :: Experimental ::
 *
 * Calculates all patterns of a projected database locally.
 */
@Experimental
private[fpm] object LocalPrefixSpan extends Logging with Serializable {
  /**
   * Calculates all patterns of a projected database.
   * @param minCount minimum count
   * @param maxPatternLength maximum pattern length
   * @param prefix prefix
   * @param projectedDatabase the projected database
   * @return a set of sequential pattern pairs,
   *         the key of pair is sequential pattern (a list of items),
   *         the value of pair is the pattern's count.
   */
  def run(
      minCount: Long,
      maxPatternLength: Int,
      prefix: Array[Int],
      projectedDatabase: Array[Array[Int]]): Array[(Array[Int], Long)] = {
    // Items that remain frequent in this projected database extend the current prefix.
    val frequentPrefixAndCounts = getFreqItemAndCounts(minCount, projectedDatabase)
    val frequentPatternAndCounts = frequentPrefixAndCounts
      .map(x => (prefix ++ Array(x._1), x._2))
    val prefixProjectedDatabases = getPatternAndProjectedDatabase(
      prefix, frequentPrefixAndCounts.map(_._1), projectedDatabase)
    val continueProcess = prefixProjectedDatabases.nonEmpty && prefix.length + 1 < maxPatternLength
    if (continueProcess) {
      // Recurse into each extended prefix's projected database and collect its patterns.
      val nextPatterns = prefixProjectedDatabases
        .map(x => run(minCount, maxPatternLength, x._1, x._2))
        .reduce(_ ++ _)
      frequentPatternAndCounts ++ nextPatterns
    } else {
      frequentPatternAndCounts
    }
  }
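  // Illustrative trace (not from the original source): with minCount = 2,
  // prefix = Array(1), and projectedDatabase = Array(Array(3, 4), Array(3, 5), Array(4)),
  // items 3 and 4 are frequent, so this call emits (Array(1, 3), 2) and (Array(1, 4), 2)
  // before recursing into the projected database of each extended prefix.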

  /**
   * Calculates the suffix sequence following a prefix item in a sequence.
   * @param prefix prefix item
   * @param sequence sequence
   * @return suffix sequence
   */
  def getSuffix(prefix: Int, sequence: Array[Int]): Array[Int] = {
    val index = sequence.indexOf(prefix)
    if (index == -1) {
      Array()
    } else {
      sequence.drop(index + 1)
    }
  }
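  // For example (illustration only): getSuffix(3, Array(1, 3, 4, 5)) returns Array(4, 5),
  // the elements after the first occurrence of 3, while getSuffix(7, Array(1, 3))
  // returns Array() because 7 does not occur in the sequence.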

  /**
   * Generates frequent items by filtering the input data using the minimal count level.
   * @param minCount the absolute minimum count
   * @param sequences sequences data
   * @return array of item and count pairs
   */
  private def getFreqItemAndCounts(
      minCount: Long,
      sequences: Array[Array[Int]]): Array[(Int, Long)] = {
    sequences.flatMap(_.distinct)
      .groupBy(x => x)
      .mapValues(_.length.toLong)
      .filter(_._2 >= minCount)
      .toArray
  }
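  // For example (illustration only): for minCount = 2 and
  // sequences = Array(Array(1, 1, 2), Array(1, 3)), the distinct call counts item 1
  // once per sequence, so the result is Array((1, 2L)); items 2 and 3 are filtered out.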

  /**
   * Gets the frequent prefixes' projected database.
   * @param prePrefix the frequent prefixes' prefix
   * @param frequentPrefixes frequent prefixes
   * @param sequences sequences data
   * @return prefixes and projected database
   */
  private def getPatternAndProjectedDatabase(
      prePrefix: Array[Int],
      frequentPrefixes: Array[Int],
      sequences: Array[Array[Int]]): Array[(Array[Int], Array[Array[Int]])] = {
    // Drop infrequent items up front; they cannot appear in any frequent pattern.
    val filteredProjectedDatabase = sequences
      .map(x => x.filter(frequentPrefixes.contains(_)))
    frequentPrefixes.map { x =>
      val sub = filteredProjectedDatabase.map(y => getSuffix(x, y)).filter(_.nonEmpty)
      (prePrefix ++ Array(x), sub)
    }.filter(x => x._2.nonEmpty)
  }
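  // For example (illustration only): with prePrefix = Array(1), frequentPrefixes = Array(3, 4),
  // and sequences = Array(Array(3, 2, 4)), item 2 is filtered out first, so the result is
  // Array((Array(1, 3), Array(Array(4)))); prefix 4 is dropped because its suffix is empty.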
}
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.mllib.fpm

import org.apache.spark.Logging
import org.apache.spark.annotation.Experimental
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel

/**
 * :: Experimental ::
 *
 * A parallel PrefixSpan algorithm to mine frequent sequential patterns.
 * The PrefixSpan algorithm is described in
 * [[http://doi.org/10.1109/ICDE.2001.914830]].
 *
 * @param minSupport the minimal support level of the sequential pattern; any pattern that
 *                   appears more than (minSupport * size-of-the-dataset) times will be output
 * @param maxPatternLength the maximal length of the sequential pattern; patterns longer than
 *                         maxPatternLength will not be output
 *
 * @see [[https://en.wikipedia.org/wiki/Sequential_Pattern_Mining Sequential Pattern Mining
 *      (Wikipedia)]]
 */
@Experimental
class PrefixSpan private (
    private var minSupport: Double,
    private var maxPatternLength: Int) extends Logging with Serializable {

  /**
   * Constructs a default instance with default parameters
   * {minSupport: `0.1`, maxPatternLength: `10`}.
   */
  def this() = this(0.1, 10)

  /**
   * Sets the minimal support level (default: `0.1`).
   */
  def setMinSupport(minSupport: Double): this.type = {
    require(minSupport >= 0 && minSupport <= 1,
      "The minimum support value must be between 0 and 1, including 0 and 1.")
    this.minSupport = minSupport
    this
  }

  /**
   * Sets the maximal pattern length (default: `10`).
   */
  def setMaxPatternLength(maxPatternLength: Int): this.type = {
    require(maxPatternLength >= 1,
      "The maximum pattern length value must be greater than 0.")
    this.maxPatternLength = maxPatternLength
    this
  }

  /**
   * Finds the complete set of sequential patterns in the input sequences.
   * @param sequences input data set, a collection of sequences,
   *                  where a sequence is an ordered list of elements.
   * @return a set of sequential pattern pairs,
   *         the key of pair is pattern (a list of elements),
   *         the value of pair is the pattern's count.
   */
  def run(sequences: RDD[Array[Int]]): RDD[(Array[Int], Long)] = {
    if (sequences.getStorageLevel == StorageLevel.NONE) {
      logWarning("Input data is not cached.")
    }
    val minCount = getMinCount(sequences)
    // Length-1 frequent patterns are computed once and collected to the driver.
    val lengthOnePatternsAndCounts =
      getFreqItemAndCounts(minCount, sequences).collect()
    val prefixAndProjectedDatabase = getPrefixAndProjectedDatabase(
      lengthOnePatternsAndCounts.map(_._1), sequences)
    // Group the projected postfixes by prefix so each group can be mined locally.
    val groupedProjectedDatabase = prefixAndProjectedDatabase
      .map(x => (x._1.toSeq, x._2))
      .groupByKey()
      .map(x => (x._1.toArray, x._2.toArray))
    val nextPatterns = getPatternsInLocal(minCount, groupedProjectedDatabase)
    val lengthOnePatternsAndCountsRdd =
      sequences.sparkContext.parallelize(
        lengthOnePatternsAndCounts.map(x => (Array(x._1), x._2)))
    lengthOnePatternsAndCountsRdd ++ nextPatterns
  }

  /**
   * Gets the minimum count (sequences count * minSupport).
   * @param sequences input data set, a collection of sequences
   * @return the minimum count
   */
  private def getMinCount(sequences: RDD[Array[Int]]): Long = {
    if (minSupport == 0) 0L else math.ceil(sequences.count() * minSupport).toLong
  }
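  // For example (illustration only): 6 input sequences with minSupport = 0.33 give
  // math.ceil(6 * 0.33) = 2, i.e. a pattern must occur in at least 2 sequences.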

  /**
   * Generates frequent items by filtering the input data using the minimal count level.
   * @param minCount the absolute minimum count
   * @param sequences original sequences data
   * @return RDD of item and count pairs
   */
  private def getFreqItemAndCounts(
      minCount: Long,
      sequences: RDD[Array[Int]]): RDD[(Int, Long)] = {
    // distinct ensures an item is counted at most once per sequence.
    sequences.flatMap(_.distinct.map((_, 1L)))
      .reduceByKey(_ + _)
      .filter(_._2 >= minCount)
  }

  /**
   * Gets the frequent prefixes' projected database.
   * @param frequentPrefixes frequent prefixes
   * @param sequences sequences data
   * @return prefixes and projected database
   */
  private def getPrefixAndProjectedDatabase(
      frequentPrefixes: Array[Int],
      sequences: RDD[Array[Int]]): RDD[(Array[Int], Array[Int])] = {
    val filteredSequences = sequences.map { p =>
      p.filter(frequentPrefixes.contains(_))
    }
    filteredSequences.flatMap { x =>
      frequentPrefixes.map { y =>
        val sub = LocalPrefixSpan.getSuffix(y, x)
        (Array(y), sub)
      }.filter(_._2.nonEmpty)
    }
  }
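  // For example (illustration only): with frequentPrefixes = Array(1, 3), the sequence
  // Array(1, 2, 3) is filtered to Array(1, 3) and contributes the pair (Array(1), Array(3));
  // the pair for prefix 3 is dropped because its suffix is empty.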

  /**
   * Calculates the patterns locally on each group.
   * @param minCount the absolute minimum count
   * @param data patterns and projected sequences data
   * @return patterns
   */
  private def getPatternsInLocal(
      minCount: Long,
      data: RDD[(Array[Int], Array[Array[Int]])]): RDD[(Array[Int], Long)] = {
    data.flatMap { x =>
      LocalPrefixSpan.run(minCount, maxPatternLength, x._1, x._2)
    }
  }
}
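A minimal usage sketch of the class above (illustrative only; it assumes a live SparkContext named `sc`, as in the test suite below):

import org.apache.spark.mllib.fpm.PrefixSpan

val sequences = sc.parallelize(Seq(
  Array(1, 3, 4, 5),
  Array(2, 3, 1),
  Array(2, 4, 1)), 2).cache()
val prefixSpan = new PrefixSpan()
  .setMinSupport(0.5)
  .setMaxPatternLength(5)
prefixSpan.run(sequences)
  .collect()
  .foreach { case (pattern, count) =>
    println(pattern.mkString("[", ", ", "]") + " -> " + count)
  }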
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.mllib.fpm

import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.rdd.RDD

class PrefixspanSuite extends SparkFunSuite with MLlibTestSparkContext {

  test("PrefixSpan using Integer type") {

    /*
      Reference results below were computed with the R package arulesSequences:

      library("arulesSequences")
      prefixSpanSeqs = read_baskets("prefixSpanSeqs", info = c("sequenceID","eventID","SIZE"))
      freqItemSeq = cspade(
        prefixSpanSeqs,
        parameter = list(support =
          2 / length(unique(transactionInfo(prefixSpanSeqs)$sequenceID)), maxlen = 2))
      resSeq = as(freqItemSeq, "data.frame")
      resSeq
    */

    val sequences = Array(
      Array(1, 3, 4, 5),
      Array(2, 3, 1),
      Array(2, 4, 1),
      Array(3, 1, 3, 4, 5),
      Array(3, 4, 4, 3),
      Array(6, 5, 3))
    val rdd = sc.parallelize(sequences, 2).cache()

    // Compares two pattern sets regardless of order. The size check matters because
    // zip would otherwise silently drop entries when one side has extra patterns.
    def compareResult(
        expectedValue: Array[(Array[Int], Long)],
        actualValue: Array[(Array[Int], Long)]): Boolean = {
      def sortKey(pattern: (Array[Int], Long)): String =
        pattern._1.mkString(",") + ":" + pattern._2
      val sortedExpectedValue = expectedValue.sortBy(sortKey)
      val sortedActualValue = actualValue.sortBy(sortKey)
      expectedValue.length == actualValue.length &&
        sortedExpectedValue.zip(sortedActualValue).forall { x =>
          x._1._1.mkString(",") == x._2._1.mkString(",") && x._1._2 == x._2._2
        }
    }

    val prefixspan = new PrefixSpan()
      .setMinSupport(0.33)
      .setMaxPatternLength(50)
    val result1 = prefixspan.run(rdd)
    val expectedValue1 = Array(
      (Array(1), 4L),
      (Array(1, 3), 2L),
      (Array(1, 3, 4), 2L),
      (Array(1, 3, 4, 5), 2L),
      (Array(1, 3, 5), 2L),
      (Array(1, 4), 2L),
      (Array(1, 4, 5), 2L),
      (Array(1, 5), 2L),
      (Array(2), 2L),
      (Array(2, 1), 2L),
      (Array(3), 5L),
      (Array(3, 1), 2L),
      (Array(3, 3), 2L),
      (Array(3, 4), 3L),
      (Array(3, 4, 5), 2L),
      (Array(3, 5), 2L),
      (Array(4), 4L),
      (Array(4, 5), 2L),
      (Array(5), 3L)
    )
    assert(compareResult(expectedValue1, result1.collect()))

    prefixspan.setMinSupport(0.5).setMaxPatternLength(50)
    val result2 = prefixspan.run(rdd)
    val expectedValue2 = Array(
      (Array(1), 4L),
      (Array(3), 5L),
      (Array(3, 4), 3L),
      (Array(4), 4L),
      (Array(5), 3L)
    )
    assert(compareResult(expectedValue2, result2.collect()))

    prefixspan.setMinSupport(0.33).setMaxPatternLength(2)
    val result3 = prefixspan.run(rdd)
    val expectedValue3 = Array(
      (Array(1), 4L),
      (Array(1, 3), 2L),
      (Array(1, 4), 2L),
      (Array(1, 5), 2L),
      (Array(2, 1), 2L),
      (Array(2), 2L),
      (Array(3), 5L),
      (Array(3, 1), 2L),
      (Array(3, 3), 2L),
      (Array(3, 4), 3L),
      (Array(3, 5), 2L),
      (Array(4), 4L),
      (Array(4, 5), 2L),
      (Array(5), 3L)
    )
    assert(compareResult(expectedValue3, result3.collect()))
  }
}