Repository: spark Updated Branches: refs/heads/master a40ffc656 -> df125062c
[SPARK-20114][ML][FOLLOW-UP] spark.ml parity for sequential pattern mining - PrefixSpan ## What changes were proposed in this pull request? Change `PrefixSpan` into a class with param setter/getters. This addresses the issues mentioned here: https://github.com/apache/spark/pull/20973#discussion_r186931806 ## How was this patch tested? UT. Please review http://spark.apache.org/contributing.html before opening a pull request. Author: WeichenXu <weichen...@databricks.com> Closes #21393 from WeichenXu123/fix_prefix_span. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/df125062 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/df125062 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/df125062 Branch: refs/heads/master Commit: df125062c8dac9fee3328d67dd438a456b7a3b74 Parents: a40ffc6 Author: WeichenXu <weichen...@databricks.com> Authored: Wed May 23 11:00:23 2018 -0700 Committer: Xiangrui Meng <m...@databricks.com> Committed: Wed May 23 11:00:23 2018 -0700 ---------------------------------------------------------------------- .../org/apache/spark/ml/fpm/PrefixSpan.scala | 127 +++++++++++++++---- .../apache/spark/ml/fpm/PrefixSpanSuite.scala | 28 ++-- 2 files changed, 119 insertions(+), 36 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/df125062/mllib/src/main/scala/org/apache/spark/ml/fpm/PrefixSpan.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/fpm/PrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/ml/fpm/PrefixSpan.scala index 02168fe..41716c6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/fpm/PrefixSpan.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/fpm/PrefixSpan.scala @@ -18,6 +18,8 @@ package org.apache.spark.ml.fpm import org.apache.spark.annotation.{Experimental, Since} +import 
org.apache.spark.ml.param._ +import org.apache.spark.ml.util.Identifiable import org.apache.spark.mllib.fpm.{PrefixSpan => mllibPrefixSpan} import org.apache.spark.sql.{DataFrame, Dataset, Row} import org.apache.spark.sql.functions.col @@ -29,13 +31,97 @@ import org.apache.spark.sql.types.{ArrayType, LongType, StructField, StructType} * The PrefixSpan algorithm is described in J. Pei, et al., PrefixSpan: Mining Sequential Patterns * Efficiently by Prefix-Projected Pattern Growth * (see <a href="http://doi.org/10.1109/ICDE.2001.914830">here</a>). + * This class is not yet an Estimator/Transformer, use `findFrequentSequentialPatterns` method to + * run the PrefixSpan algorithm. * * @see <a href="https://en.wikipedia.org/wiki/Sequential_Pattern_Mining">Sequential Pattern Mining * (Wikipedia)</a> */ @Since("2.4.0") @Experimental -object PrefixSpan { +final class PrefixSpan(@Since("2.4.0") override val uid: String) extends Params { + + @Since("2.4.0") + def this() = this(Identifiable.randomUID("prefixSpan")) + + /** + * Param for the minimal support level (default: `0.1`). + * Sequential patterns that appear more than (minSupport * size-of-the-dataset) times are + * identified as frequent sequential patterns. + * @group param + */ + @Since("2.4.0") + val minSupport = new DoubleParam(this, "minSupport", "The minimal support level of the " + + "sequential pattern. Sequential pattern that appears more than " + + "(minSupport * size-of-the-dataset)." + + "times will be output.", ParamValidators.gtEq(0.0)) + + /** @group getParam */ + @Since("2.4.0") + def getMinSupport: Double = $(minSupport) + + /** @group setParam */ + @Since("2.4.0") + def setMinSupport(value: Double): this.type = set(minSupport, value) + + /** + * Param for the maximal pattern length (default: `10`). 
+ * @group param + */ + @Since("2.4.0") + val maxPatternLength = new IntParam(this, "maxPatternLength", + "The maximal length of the sequential pattern.", + ParamValidators.gt(0)) + + /** @group getParam */ + @Since("2.4.0") + def getMaxPatternLength: Int = $(maxPatternLength) + + /** @group setParam */ + @Since("2.4.0") + def setMaxPatternLength(value: Int): this.type = set(maxPatternLength, value) + + /** + * Param for the maximum number of items (including delimiters used in the internal storage + * format) allowed in a projected database before local processing (default: `32000000`). + * If a projected database exceeds this size, another iteration of distributed prefix growth + * is run. + * @group param + */ + @Since("2.4.0") + val maxLocalProjDBSize = new LongParam(this, "maxLocalProjDBSize", + "The maximum number of items (including delimiters used in the internal storage format) " + + "allowed in a projected database before local processing. If a projected database exceeds " + + "this size, another iteration of distributed prefix growth is run.", + ParamValidators.gt(0)) + + /** @group getParam */ + @Since("2.4.0") + def getMaxLocalProjDBSize: Long = $(maxLocalProjDBSize) + + /** @group setParam */ + @Since("2.4.0") + def setMaxLocalProjDBSize(value: Long): this.type = set(maxLocalProjDBSize, value) + + /** + * Param for the name of the sequence column in dataset (default "sequence"), rows with + * nulls in this column are ignored. 
+ * @group param + */ + @Since("2.4.0") + val sequenceCol = new Param[String](this, "sequenceCol", "The name of the sequence column in " + + "dataset, rows with nulls in this column are ignored.") + + /** @group getParam */ + @Since("2.4.0") + def getSequenceCol: String = $(sequenceCol) + + /** @group setParam */ + @Since("2.4.0") + def setSequenceCol(value: String): this.type = set(sequenceCol, value) + + setDefault(minSupport -> 0.1, maxPatternLength -> 10, maxLocalProjDBSize -> 32000000, + sequenceCol -> "sequence") /** * :: Experimental :: @@ -43,54 +129,39 @@ object PrefixSpan { * * @param dataset A dataset or a dataframe containing a sequence column which is * {{{Seq[Seq[_]]}}} type - * @param sequenceCol the name of the sequence column in dataset, rows with nulls in this column - * are ignored - * @param minSupport the minimal support level of the sequential pattern, any pattern that - * appears more than (minSupport * size-of-the-dataset) times will be output - * (recommended value: `0.1`). - * @param maxPatternLength the maximal length of the sequential pattern - * (recommended value: `10`). - * @param maxLocalProjDBSize The maximum number of items (including delimiters used in the - * internal storage format) allowed in a projected database before - * local processing. If a projected database exceeds this size, another - * iteration of distributed prefix growth is run - * (recommended value: `32000000`). * @return A `DataFrame` that contains columns of sequence and corresponding frequency. 
* The schema of it will be: * - `sequence: Seq[Seq[T]]` (T is the item type) * - `freq: Long` */ @Since("2.4.0") - def findFrequentSequentialPatterns( - dataset: Dataset[_], - sequenceCol: String, - minSupport: Double, - maxPatternLength: Int, - maxLocalProjDBSize: Long): DataFrame = { - - val inputType = dataset.schema(sequenceCol).dataType + def findFrequentSequentialPatterns(dataset: Dataset[_]): DataFrame = { + val sequenceColParam = $(sequenceCol) + val inputType = dataset.schema(sequenceColParam).dataType require(inputType.isInstanceOf[ArrayType] && inputType.asInstanceOf[ArrayType].elementType.isInstanceOf[ArrayType], s"The input column must be ArrayType and the array element type must also be ArrayType, " + s"but got $inputType.") - - val data = dataset.select(sequenceCol) - val sequences = data.where(col(sequenceCol).isNotNull).rdd + val data = dataset.select(sequenceColParam) + val sequences = data.where(col(sequenceColParam).isNotNull).rdd .map(r => r.getAs[Seq[Seq[Any]]](0).map(_.toArray).toArray) val mllibPrefixSpan = new mllibPrefixSpan() - .setMinSupport(minSupport) - .setMaxPatternLength(maxPatternLength) - .setMaxLocalProjDBSize(maxLocalProjDBSize) + .setMinSupport($(minSupport)) + .setMaxPatternLength($(maxPatternLength)) + .setMaxLocalProjDBSize($(maxLocalProjDBSize)) val rows = mllibPrefixSpan.run(sequences).freqSequences.map(f => Row(f.sequence, f.freq)) val schema = StructType(Seq( - StructField("sequence", dataset.schema(sequenceCol).dataType, nullable = false), + StructField("sequence", dataset.schema(sequenceColParam).dataType, nullable = false), StructField("freq", LongType, nullable = false))) val freqSequences = dataset.sparkSession.createDataFrame(rows, schema) freqSequences } + @Since("2.4.0") + override def copy(extra: ParamMap): PrefixSpan = defaultCopy(extra) + } http://git-wip-us.apache.org/repos/asf/spark/blob/df125062/mllib/src/test/scala/org/apache/spark/ml/fpm/PrefixSpanSuite.scala 
---------------------------------------------------------------------- diff --git a/mllib/src/test/scala/org/apache/spark/ml/fpm/PrefixSpanSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/fpm/PrefixSpanSuite.scala index 9e53869..2252151 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/fpm/PrefixSpanSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/fpm/PrefixSpanSuite.scala @@ -29,8 +29,11 @@ class PrefixSpanSuite extends MLTest { test("PrefixSpan projections with multiple partial starts") { val smallDataset = Seq(Seq(Seq(1, 2), Seq(1, 2, 3))).toDF("sequence") - val result = PrefixSpan.findFrequentSequentialPatterns(smallDataset, "sequence", - minSupport = 1.0, maxPatternLength = 2, maxLocalProjDBSize = 32000000) + val result = new PrefixSpan() + .setMinSupport(1.0) + .setMaxPatternLength(2) + .setMaxLocalProjDBSize(32000000) + .findFrequentSequentialPatterns(smallDataset) .as[(Seq[Seq[Int]], Long)].collect() val expected = Array( (Seq(Seq(1)), 1L), @@ -90,8 +93,11 @@ class PrefixSpanSuite extends MLTest { test("PrefixSpan Integer type, variable-size itemsets") { val df = smallTestData.toDF("sequence") - val result = PrefixSpan.findFrequentSequentialPatterns(df, "sequence", - minSupport = 0.5, maxPatternLength = 5, maxLocalProjDBSize = 32000000) + val result = new PrefixSpan() + .setMinSupport(0.5) + .setMaxPatternLength(5) + .setMaxLocalProjDBSize(32000000) + .findFrequentSequentialPatterns(df) .as[(Seq[Seq[Int]], Long)].collect() compareResults[Int](smallTestDataExpectedResult, result) @@ -99,8 +105,11 @@ class PrefixSpanSuite extends MLTest { test("PrefixSpan input row with nulls") { val df = (smallTestData :+ null).toDF("sequence") - val result = PrefixSpan.findFrequentSequentialPatterns(df, "sequence", - minSupport = 0.5, maxPatternLength = 5, maxLocalProjDBSize = 32000000) + val result = new PrefixSpan() + .setMinSupport(0.5) + .setMaxPatternLength(5) + .setMaxLocalProjDBSize(32000000) + .findFrequentSequentialPatterns(df) 
.as[(Seq[Seq[Int]], Long)].collect() compareResults[Int](smallTestDataExpectedResult, result) @@ -111,8 +120,11 @@ class PrefixSpanSuite extends MLTest { val df = smallTestData .map(seq => seq.map(itemSet => itemSet.map(intToString))) .toDF("sequence") - val result = PrefixSpan.findFrequentSequentialPatterns(df, "sequence", - minSupport = 0.5, maxPatternLength = 5, maxLocalProjDBSize = 32000000) + val result = new PrefixSpan() + .setMinSupport(0.5) + .setMaxPatternLength(5) + .setMaxLocalProjDBSize(32000000) + .findFrequentSequentialPatterns(df) .as[(Seq[Seq[String]], Long)].collect() val expected = smallTestDataExpectedResult.map { case (seq, freq) => --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org