Repository: spark
Updated Branches:
  refs/heads/master a40ffc656 -> df125062c


[SPARK-20114][ML][FOLLOW-UP] spark.ml parity for sequential pattern mining - PrefixSpan

## What changes were proposed in this pull request?

Change `PrefixSpan` into a class with param setter/getters.
This addresses the issues mentioned here:
https://github.com/apache/spark/pull/20973#discussion_r186931806
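
In short, callers now construct a `PrefixSpan` instance and configure it via the param setters before calling `findFrequentSequentialPatterns`. A minimal usage sketch (the data and parameter values below are illustrative, and a `SparkSession` with `spark.implicits._` in scope is assumed):

```scala
import org.apache.spark.ml.fpm.PrefixSpan

// Illustrative toy data: each row is a sequence of itemsets.
val df = Seq(
  Seq(Seq(1, 2), Seq(3)),
  Seq(Seq(1), Seq(3, 2), Seq(1, 2))
).toDF("sequence")

val patterns = new PrefixSpan()
  .setMinSupport(0.5)
  .setMaxPatternLength(5)
  .findFrequentSequentialPatterns(df)

// Result schema: sequence: Seq[Seq[Int]] and freq: Long.
patterns.show()
```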

## How was this patch tested?

Unit tests (updated `PrefixSpanSuite`).

Please review http://spark.apache.org/contributing.html before opening a pull request.

Author: WeichenXu <weichen...@databricks.com>

Closes #21393 from WeichenXu123/fix_prefix_span.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/df125062
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/df125062
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/df125062

Branch: refs/heads/master
Commit: df125062c8dac9fee3328d67dd438a456b7a3b74
Parents: a40ffc6
Author: WeichenXu <weichen...@databricks.com>
Authored: Wed May 23 11:00:23 2018 -0700
Committer: Xiangrui Meng <m...@databricks.com>
Committed: Wed May 23 11:00:23 2018 -0700

----------------------------------------------------------------------
 .../org/apache/spark/ml/fpm/PrefixSpan.scala    | 127 +++++++++++++++----
 .../apache/spark/ml/fpm/PrefixSpanSuite.scala   |  28 ++--
 2 files changed, 119 insertions(+), 36 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/df125062/mllib/src/main/scala/org/apache/spark/ml/fpm/PrefixSpan.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/fpm/PrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/ml/fpm/PrefixSpan.scala
index 02168fe..41716c6 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/fpm/PrefixSpan.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/fpm/PrefixSpan.scala
@@ -18,6 +18,8 @@
 package org.apache.spark.ml.fpm
 
 import org.apache.spark.annotation.{Experimental, Since}
+import org.apache.spark.ml.param._
+import org.apache.spark.ml.util.Identifiable
 import org.apache.spark.mllib.fpm.{PrefixSpan => mllibPrefixSpan}
 import org.apache.spark.sql.{DataFrame, Dataset, Row}
 import org.apache.spark.sql.functions.col
@@ -29,13 +31,97 @@ import org.apache.spark.sql.types.{ArrayType, LongType, StructField, StructType}
 * The PrefixSpan algorithm is described in J. Pei, et al., PrefixSpan: Mining Sequential Patterns
  * Efficiently by Prefix-Projected Pattern Growth
 * (see <a href="http://doi.org/10.1109/ICDE.2001.914830">here</a>).
+ * This class is not yet an Estimator/Transformer. Use the `findFrequentSequentialPatterns` method
+ * to run the PrefixSpan algorithm.
  *
 * @see <a href="https://en.wikipedia.org/wiki/Sequential_Pattern_Mining">Sequential Pattern Mining
 * (Wikipedia)</a>
  */
 @Since("2.4.0")
 @Experimental
-object PrefixSpan {
+final class PrefixSpan(@Since("2.4.0") override val uid: String) extends Params {
+
+  @Since("2.4.0")
+  def this() = this(Identifiable.randomUID("prefixSpan"))
+
+  /**
+   * Param for the minimal support level (default: `0.1`).
+   * Sequential patterns that appear more than (minSupport * size-of-the-dataset) times are
+   * identified as frequent sequential patterns.
+   * @group param
+   */
+  @Since("2.4.0")
+  val minSupport = new DoubleParam(this, "minSupport", "The minimal support level of the " +
+    "sequential pattern. Sequential patterns that appear more than " +
+    "(minSupport * size-of-the-dataset) times will be output.", ParamValidators.gtEq(0.0))
+
+  /** @group getParam */
+  @Since("2.4.0")
+  def getMinSupport: Double = $(minSupport)
+
+  /** @group setParam */
+  @Since("2.4.0")
+  def setMinSupport(value: Double): this.type = set(minSupport, value)
+
+  /**
+   * Param for the maximal pattern length (default: `10`).
+   * @group param
+   */
+  @Since("2.4.0")
+  val maxPatternLength = new IntParam(this, "maxPatternLength",
+    "The maximal length of the sequential pattern.",
+    ParamValidators.gt(0))
+
+  /** @group getParam */
+  @Since("2.4.0")
+  def getMaxPatternLength: Int = $(maxPatternLength)
+
+  /** @group setParam */
+  @Since("2.4.0")
+  def setMaxPatternLength(value: Int): this.type = set(maxPatternLength, value)
+
+  /**
+   * Param for the maximum number of items (including delimiters used in the internal storage
+   * format) allowed in a projected database before local processing (default: `32000000`).
+   * If a projected database exceeds this size, another iteration of distributed prefix growth
+   * is run.
+   * @group param
+   */
+  @Since("2.4.0")
+  val maxLocalProjDBSize = new LongParam(this, "maxLocalProjDBSize",
+    "The maximum number of items (including delimiters used in the internal 
storage format) " +
+    "allowed in a projected database before local processing. If a projected 
database exceeds " +
+    "this size, another iteration of distributed prefix growth is run.",
+    ParamValidators.gt(0))
+
+  /** @group getParam */
+  @Since("2.4.0")
+  def getMaxLocalProjDBSize: Long = $(maxLocalProjDBSize)
+
+  /** @group setParam */
+  @Since("2.4.0")
+  def setMaxLocalProjDBSize(value: Long): this.type = set(maxLocalProjDBSize, value)
+
+  /**
+   * Param for the name of the sequence column in dataset (default "sequence"), rows with
+   * nulls in this column are ignored.
+   * @group param
+   */
+  @Since("2.4.0")
+  val sequenceCol = new Param[String](this, "sequenceCol", "The name of the sequence column in " +
+    "dataset, rows with nulls in this column are ignored.")
+
+  /** @group getParam */
+  @Since("2.4.0")
+  def getSequenceCol: String = $(sequenceCol)
+
+  /** @group setParam */
+  @Since("2.4.0")
+  def setSequenceCol(value: String): this.type = set(sequenceCol, value)
+
+  setDefault(minSupport -> 0.1, maxPatternLength -> 10, maxLocalProjDBSize -> 32000000,
+    sequenceCol -> "sequence")
 
   /**
    * :: Experimental ::
@@ -43,54 +129,39 @@ object PrefixSpan {
    *
   * @param dataset A dataset or a dataframe containing a sequence column which is
    *                {{{Seq[Seq[_]]}}} type
-   * @param sequenceCol the name of the sequence column in dataset, rows with nulls in this column
-   *                    are ignored
-   * @param minSupport the minimal support level of the sequential pattern, any pattern that
-   *                   appears more than (minSupport * size-of-the-dataset) times will be output
-   *                  (recommended value: `0.1`).
-   * @param maxPatternLength the maximal length of the sequential pattern
-   *                         (recommended value: `10`).
-   * @param maxLocalProjDBSize The maximum number of items (including delimiters used in the
-   *                           internal storage format) allowed in a projected database before
-   *                           local processing. If a projected database exceeds this size, another
-   *                           iteration of distributed prefix growth is run
-   *                           (recommended value: `32000000`).
   * @return A `DataFrame` that contains columns of sequence and corresponding frequency.
    *         The schema of it will be:
    *          - `sequence: Seq[Seq[T]]` (T is the item type)
    *          - `freq: Long`
    */
   @Since("2.4.0")
-  def findFrequentSequentialPatterns(
-      dataset: Dataset[_],
-      sequenceCol: String,
-      minSupport: Double,
-      maxPatternLength: Int,
-      maxLocalProjDBSize: Long): DataFrame = {
-
-    val inputType = dataset.schema(sequenceCol).dataType
+  def findFrequentSequentialPatterns(dataset: Dataset[_]): DataFrame = {
+    val sequenceColParam = $(sequenceCol)
+    val inputType = dataset.schema(sequenceColParam).dataType
     require(inputType.isInstanceOf[ArrayType] &&
       inputType.asInstanceOf[ArrayType].elementType.isInstanceOf[ArrayType],
       s"The input column must be ArrayType and the array element type must 
also be ArrayType, " +
       s"but got $inputType.")
 
-
-    val data = dataset.select(sequenceCol)
-    val sequences = data.where(col(sequenceCol).isNotNull).rdd
+    val data = dataset.select(sequenceColParam)
+    val sequences = data.where(col(sequenceColParam).isNotNull).rdd
       .map(r => r.getAs[Seq[Seq[Any]]](0).map(_.toArray).toArray)
 
     val mllibPrefixSpan = new mllibPrefixSpan()
-      .setMinSupport(minSupport)
-      .setMaxPatternLength(maxPatternLength)
-      .setMaxLocalProjDBSize(maxLocalProjDBSize)
+      .setMinSupport($(minSupport))
+      .setMaxPatternLength($(maxPatternLength))
+      .setMaxLocalProjDBSize($(maxLocalProjDBSize))
 
    val rows = mllibPrefixSpan.run(sequences).freqSequences.map(f => Row(f.sequence, f.freq))
     val schema = StructType(Seq(
-      StructField("sequence", dataset.schema(sequenceCol).dataType, nullable = false),
+      StructField("sequence", dataset.schema(sequenceColParam).dataType, nullable = false),
       StructField("freq", LongType, nullable = false)))
     val freqSequences = dataset.sparkSession.createDataFrame(rows, schema)
 
     freqSequences
   }
 
+  @Since("2.4.0")
+  override def copy(extra: ParamMap): PrefixSpan = defaultCopy(extra)
+
 }
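
A note on the `require` check above: the sequence column must be an array-of-arrays type. A small sketch of conforming vs. non-conforming input (toy data, assuming `spark.implicits._` is in scope):

```scala
// Conforming: each row is Seq[Seq[Int]], i.e. ArrayType(ArrayType(IntegerType)).
val ok = Seq(Seq(Seq(1, 2), Seq(3))).toDF("sequence")

// Non-conforming: a flat Seq[Int] column is only ArrayType(IntegerType);
// findFrequentSequentialPatterns rejects it via the require() above,
// throwing an IllegalArgumentException.
val bad = Seq(Seq(1, 2, 3)).toDF("sequence")
```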

http://git-wip-us.apache.org/repos/asf/spark/blob/df125062/mllib/src/test/scala/org/apache/spark/ml/fpm/PrefixSpanSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/fpm/PrefixSpanSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/fpm/PrefixSpanSuite.scala
index 9e53869..2252151 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/fpm/PrefixSpanSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/fpm/PrefixSpanSuite.scala
@@ -29,8 +29,11 @@ class PrefixSpanSuite extends MLTest {
 
   test("PrefixSpan projections with multiple partial starts") {
     val smallDataset = Seq(Seq(Seq(1, 2), Seq(1, 2, 3))).toDF("sequence")
-    val result = PrefixSpan.findFrequentSequentialPatterns(smallDataset, "sequence",
-      minSupport = 1.0, maxPatternLength = 2, maxLocalProjDBSize = 32000000)
+    val result = new PrefixSpan()
+      .setMinSupport(1.0)
+      .setMaxPatternLength(2)
+      .setMaxLocalProjDBSize(32000000)
+      .findFrequentSequentialPatterns(smallDataset)
       .as[(Seq[Seq[Int]], Long)].collect()
     val expected = Array(
       (Seq(Seq(1)), 1L),
@@ -90,8 +93,11 @@ class PrefixSpanSuite extends MLTest {
 
   test("PrefixSpan Integer type, variable-size itemsets") {
     val df = smallTestData.toDF("sequence")
-    val result = PrefixSpan.findFrequentSequentialPatterns(df, "sequence",
-      minSupport = 0.5, maxPatternLength = 5, maxLocalProjDBSize = 32000000)
+    val result = new PrefixSpan()
+      .setMinSupport(0.5)
+      .setMaxPatternLength(5)
+      .setMaxLocalProjDBSize(32000000)
+      .findFrequentSequentialPatterns(df)
       .as[(Seq[Seq[Int]], Long)].collect()
 
     compareResults[Int](smallTestDataExpectedResult, result)
@@ -99,8 +105,11 @@ class PrefixSpanSuite extends MLTest {
 
   test("PrefixSpan input row with nulls") {
     val df = (smallTestData :+ null).toDF("sequence")
-    val result = PrefixSpan.findFrequentSequentialPatterns(df, "sequence",
-      minSupport = 0.5, maxPatternLength = 5, maxLocalProjDBSize = 32000000)
+    val result = new PrefixSpan()
+      .setMinSupport(0.5)
+      .setMaxPatternLength(5)
+      .setMaxLocalProjDBSize(32000000)
+      .findFrequentSequentialPatterns(df)
       .as[(Seq[Seq[Int]], Long)].collect()
 
     compareResults[Int](smallTestDataExpectedResult, result)
@@ -111,8 +120,11 @@ class PrefixSpanSuite extends MLTest {
     val df = smallTestData
       .map(seq => seq.map(itemSet => itemSet.map(intToString)))
       .toDF("sequence")
-    val result = PrefixSpan.findFrequentSequentialPatterns(df, "sequence",
-      minSupport = 0.5, maxPatternLength = 5, maxLocalProjDBSize = 32000000)
+    val result = new PrefixSpan()
+      .setMinSupport(0.5)
+      .setMaxPatternLength(5)
+      .setMaxLocalProjDBSize(32000000)
+      .findFrequentSequentialPatterns(df)
       .as[(Seq[Seq[String]], Long)].collect()
 
     val expected = smallTestDataExpectedResult.map { case (seq, freq) =>

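The updated tests all rely on the default `sequenceCol` value of "sequence"; for completeness, a hypothetical sketch of the new `setSequenceCol` setter with a non-default column name (not part of this patch):

```scala
// Hypothetical: the same test data stored under a non-default column name.
val df = smallTestData.toDF("events")
val result = new PrefixSpan()
  .setSequenceCol("events")
  .setMinSupport(0.5)
  .setMaxPatternLength(5)
  .findFrequentSequentialPatterns(df)
```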
