Github user feynmanliang commented on a diff in the pull request:
https://github.com/apache/spark/pull/7937#discussion_r36242946
--- Diff:
mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala ---
@@ -22,85 +22,89 @@ import scala.collection.mutable
import org.apache.spark.Logging
/**
- * Calculate all patterns of a projected database in local.
+ * Calculate all patterns of a projected database in local mode.
+ *
+ * @param minCount minimal count for a frequent pattern
+ * @param maxPatternLength max pattern length for a frequent pattern
*/
-private[fpm] object LocalPrefixSpan extends Logging with Serializable {
- import PrefixSpan._
+private[fpm] class LocalPrefixSpan(
+ val minCount: Long,
+ val maxPatternLength: Int) extends Logging with Serializable {
+ import PrefixSpan.Postfix
+ import LocalPrefixSpan.ReversedPrefix
+
/**
- * Calculate all patterns of a projected database.
- * @param minCount minimum count
- * @param maxPatternLength maximum pattern length
- * @param prefixes prefixes in reversed order
- * @param database the projected database
- * @return a set of sequential pattern pairs,
- * the key of pair is sequential pattern (a list of items in
reversed order),
- * the value of pair is the pattern's count.
+ * Generates frequent patterns on the input array of postfixes.
+ * @param postfixes an array of postfixes
+ * @return an iterator of (frequent pattern, count)
*/
- def run(
- minCount: Long,
- maxPatternLength: Int,
- prefixes: List[Set[Int]],
- database: Iterable[List[Set[Int]]]): Iterator[(List[Set[Int]],
Long)] = {
- if (prefixes.length == maxPatternLength || database.isEmpty) {
- return Iterator.empty
- }
- val freqItemSetsAndCounts = getFreqItemAndCounts(minCount, database)
- val freqItems = freqItemSetsAndCounts.keys.flatten.toSet
- val filteredDatabase = database.map { suffix =>
- suffix
- .map(item => freqItems.intersect(item))
- .filter(_.nonEmpty)
- }
- freqItemSetsAndCounts.iterator.flatMap { case (item, count) =>
- val newPrefixes = item :: prefixes
- val newProjected = project(filteredDatabase, item)
- Iterator.single((newPrefixes, count)) ++
- run(minCount, maxPatternLength, newPrefixes, newProjected)
+ def run(postfixes: Array[Postfix]): Iterator[(Array[Int], Long)] = {
+ genFreqPatterns(ReversedPrefix.empty, postfixes).map { case (prefix,
count) =>
+ (prefix.toSequence, count)
}
}
/**
- * Calculate suffix sequence immediately after the first occurrence of
an item.
- * @param item itemset to get suffix after
- * @param sequence sequence to extract suffix from
- * @return suffix sequence
+ * Recursively generates frequent patterns.
+ * @param prefix current prefix
+ * @param postfixes projected postfixes w.r.t. the prefix
+ * @return an iterator of (prefix, count)
*/
- def getSuffix(item: Set[Int], sequence: List[Set[Int]]): List[Set[Int]]
= {
- val itemsetSeq = sequence
- val index = itemsetSeq.indexWhere(item.subsetOf(_))
- if (index == -1) {
- List()
- } else {
- itemsetSeq.drop(index + 1)
+ private def genFreqPatterns(
+ prefix: ReversedPrefix,
+ postfixes: Array[Postfix]): Iterator[(ReversedPrefix, Long)] = {
+ if (maxPatternLength == prefix.length || postfixes.length < minCount) {
+ return Iterator.empty
+ }
+ // find frequent items
+ val counts = mutable.Map.empty[Int, Long].withDefaultValue(0)
+ postfixes.foreach { postfix =>
+ postfix.genPrefixItems.foreach { case (x, _) =>
+ counts(x) = counts(x) + 1L
--- End diff --
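Nit: use `+=` here, i.e. `counts(x) += 1L` instead of `counts(x) = counts(x) + 1L` (Scala desugars it to the same `update` call on the map).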
---
If your project is set up for it, you can reply to this email and have your
reply appear on GitHub as well. If your project does not have this feature
enabled and would like it, or if the feature is enabled but not working, please
contact infrastructure at [email protected] or file a JIRA ticket
with INFRA.
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]