Repository: spark
Updated Branches:
  refs/heads/master 51d3c854c -> b34f7665d


[SPARK-19825][R][ML] spark.ml R API for FPGrowth

## What changes were proposed in this pull request?

Adds SparkR API for FPGrowth
([SPARK-19825](https://issues.apache.org/jira/browse/SPARK-19825)); a short
usage sketch follows the list:

- `spark.fpGrowth` - model training.
- `spark.freqItemsets` and `spark.associationRules` methods with new
  corresponding generics.
- Scala helper: `org.apache.spark.ml.r.FPGrowthWrapper`.
- Unit tests.
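
A minimal usage sketch of the new API (adapted from the roxygen examples in
this patch; assumes an active SparkR session and Spark's bundled
`data/mllib/sample_fpgrowth.txt`):

```r
# Load transactions: one space-separated basket per line
raw_data <- read.df(
  "data/mllib/sample_fpgrowth.txt",
  source = "csv",
  schema = structType(structField("raw_items", "string")))
data <- selectExpr(raw_data, "split(raw_items, ' ') as items")

# Fit with the default thresholds (minSupport = 0.3, minConfidence = 0.8)
model <- spark.fpGrowth(data)

showDF(spark.freqItemsets(model))      # columns: items, freq
showDF(spark.associationRules(model))  # columns: antecedent, consequent, confidence
```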

## How was this patch tested?

Feature specific unit tests.

Author: zero323 <zero...@users.noreply.github.com>

Closes #17170 from zero323/SPARK-19825.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/b34f7665
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/b34f7665
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/b34f7665

Branch: refs/heads/master
Commit: b34f7665ddb0a40044b4c2bc7d351599c125cb13
Parents: 51d3c85
Author: zero323 <zero...@users.noreply.github.com>
Authored: Mon Apr 3 23:42:04 2017 -0700
Committer: Felix Cheung <felixche...@apache.org>
Committed: Mon Apr 3 23:42:04 2017 -0700

----------------------------------------------------------------------
 R/pkg/DESCRIPTION                               |   1 +
 R/pkg/NAMESPACE                                 |   5 +-
 R/pkg/R/generics.R                              |  12 ++
 R/pkg/R/mllib_fpm.R                             | 158 +++++++++++++++++++
 R/pkg/R/mllib_utils.R                           |   2 +
 R/pkg/inst/tests/testthat/test_mllib_fpm.R      |  83 ++++++++++
 .../org/apache/spark/ml/r/FPGrowthWrapper.scala |  86 ++++++++++
 .../scala/org/apache/spark/ml/r/RWrappers.scala |   2 +
 8 files changed, 348 insertions(+), 1 deletion(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/b34f7665/R/pkg/DESCRIPTION
----------------------------------------------------------------------
diff --git a/R/pkg/DESCRIPTION b/R/pkg/DESCRIPTION
index 00dde64..f475ee8 100644
--- a/R/pkg/DESCRIPTION
+++ b/R/pkg/DESCRIPTION
@@ -44,6 +44,7 @@ Collate:
     'jvm.R'
     'mllib_classification.R'
     'mllib_clustering.R'
+    'mllib_fpm.R'
     'mllib_recommendation.R'
     'mllib_regression.R'
     'mllib_stat.R'

http://git-wip-us.apache.org/repos/asf/spark/blob/b34f7665/R/pkg/NAMESPACE
----------------------------------------------------------------------
diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE
index c02046c..9b7e95c 100644
--- a/R/pkg/NAMESPACE
+++ b/R/pkg/NAMESPACE
@@ -66,7 +66,10 @@ exportMethods("glm",
               "spark.randomForest",
               "spark.gbt",
               "spark.bisectingKmeans",
-              "spark.svmLinear")
+              "spark.svmLinear",
+              "spark.fpGrowth",
+              "spark.freqItemsets",
+              "spark.associationRules")
 
 # Job group lifecycle management methods
 export("setJobGroup",

http://git-wip-us.apache.org/repos/asf/spark/blob/b34f7665/R/pkg/R/generics.R
----------------------------------------------------------------------
diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R
index 80283e4..945676c 100644
--- a/R/pkg/R/generics.R
+++ b/R/pkg/R/generics.R
@@ -1445,6 +1445,18 @@ setGeneric("spark.posterior", function(object, newData) { standardGeneric("spark
 #' @export
 setGeneric("spark.perplexity", function(object, data) { standardGeneric("spark.perplexity") })
 
+#' @rdname spark.fpGrowth
+#' @export
+setGeneric("spark.fpGrowth", function(data, ...) { 
standardGeneric("spark.fpGrowth") })
+
+#' @rdname spark.fpGrowth
+#' @export
+setGeneric("spark.freqItemsets", function(object) { 
standardGeneric("spark.freqItemsets") })
+
+#' @rdname spark.fpGrowth
+#' @export
+setGeneric("spark.associationRules", function(object) { 
standardGeneric("spark.associationRules") })
+
 #' @param object a fitted ML model object.
 #' @param path the directory where the model is saved.
 #' @param ... additional argument(s) passed to the method.

http://git-wip-us.apache.org/repos/asf/spark/blob/b34f7665/R/pkg/R/mllib_fpm.R
----------------------------------------------------------------------
diff --git a/R/pkg/R/mllib_fpm.R b/R/pkg/R/mllib_fpm.R
new file mode 100644
index 0000000..96251b2
--- /dev/null
+++ b/R/pkg/R/mllib_fpm.R
@@ -0,0 +1,158 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# mllib_fpm.R: Provides methods for MLlib frequent pattern mining algorithms integration
+
+#' S4 class that represents an FPGrowthModel
+#'
+#' @param jobj a Java object reference to the backing Scala FPGrowthModel
+#' @export
+#' @note FPGrowthModel since 2.2.0
+setClass("FPGrowthModel", slots = list(jobj = "jobj"))
+
+#' FP-growth
+#'
+#' A parallel FP-growth algorithm to mine frequent itemsets.
+#' For more details, see
+#' \href{https://spark.apache.org/docs/latest/mllib-frequent-pattern-mining.html#fp-growth}{
+#' FP-growth}.
+#'
+#' @param data A SparkDataFrame for training.
+#' @param minSupport Minimal support level.
+#' @param minConfidence Minimal confidence level.
+#' @param itemsCol Features column name.
+#' @param numPartitions Number of partitions used for fitting.
+#' @param ... additional argument(s) passed to the method.
+#' @return \code{spark.fpGrowth} returns a fitted FPGrowth model.
+#' @rdname spark.fpGrowth
+#' @name spark.fpGrowth
+#' @aliases spark.fpGrowth,SparkDataFrame-method
+#' @export
+#' @examples
+#' \dontrun{
+#' raw_data <- read.df(
+#'   "data/mllib/sample_fpgrowth.txt",
+#'   source = "csv",
+#'   schema = structType(structField("raw_items", "string")))
+#'
+#' data <- selectExpr(raw_data, "split(raw_items, ' ') as items")
+#' model <- spark.fpGrowth(data)
+#'
+#' # Show frequent itemsets
+#' frequent_itemsets <- spark.freqItemsets(model)
+#' showDF(frequent_itemsets)
+#'
+#' # Show association rules
+#' association_rules <- spark.associationRules(model)
+#' showDF(association_rules)
+#'
+#' # Predict on new data
+#' new_itemsets <- data.frame(items = c("t", "t,s"))
+#' new_data <- selectExpr(createDataFrame(new_itemsets), "split(items, ',') as items")
+#' predict(model, new_data)
+#'
+#' # Save and load model
+#' path <- "/path/to/model"
+#' write.ml(model, path)
+#' read.ml(path)
+#'
+#' # Optional arguments
+#' baskets_data <- selectExpr(createDataFrame(new_itemsets), "split(items, ',') as baskets")
+#' another_model <- spark.fpGrowth(baskets_data, minSupport = 0.1, minConfidence = 0.5,
+#'                                 itemsCol = "baskets", numPartitions = 10)
+#' }
+#' @note spark.fpGrowth since 2.2.0
+setMethod("spark.fpGrowth", signature(data = "SparkDataFrame"),
+          function(data, minSupport = 0.3, minConfidence = 0.8,
+                   itemsCol = "items", numPartitions = NULL) {
+            if (!is.numeric(minSupport) || minSupport < 0 || minSupport > 1) {
+              stop("minSupport should be a number in [0, 1].")
+            }
+            if (!is.numeric(minConfidence) || minConfidence < 0 || minConfidence > 1) {
+              stop("minConfidence should be a number in [0, 1].")
+            }
+            if (!is.null(numPartitions)) {
+              numPartitions <- as.integer(numPartitions)
+              stopifnot(numPartitions > 0)
+            }
+
+            jobj <- callJStatic("org.apache.spark.ml.r.FPGrowthWrapper", "fit",
+                                data@sdf, as.numeric(minSupport), as.numeric(minConfidence),
+                                itemsCol, numPartitions)
+            new("FPGrowthModel", jobj = jobj)
+          })
+
+# Get frequent itemsets.
+
+#' @param object a fitted FPGrowth model.
+#' @return A \code{SparkDataFrame} with frequent itemsets.
+#'         The \code{SparkDataFrame} contains two columns:
+#'         \code{items} (an array of the same type as the input column)
+#'         and \code{freq} (frequency of the itemset).
+#' @rdname spark.fpGrowth
+#' @aliases spark.freqItemsets,FPGrowthModel-method
+#' @export
+#' @note spark.freqItemsets(FPGrowthModel) since 2.2.0
+setMethod("spark.freqItemsets", signature(object = "FPGrowthModel"),
+          function(object) {
+            dataFrame(callJMethod(object@jobj, "freqItemsets"))
+          })
+
+# Get association rules.
+
+#' @return A \code{SparkDataFrame} with association rules.
+#'         The \code{SparkDataFrame} contains three columns:
+#'         \code{antecedent} (an array of the same type as the input column),
+#'         \code{consequent} (an array of the same type as the input column),
+#'         and \code{confidence} (confidence of the rule).
+#' @rdname spark.fpGrowth
+#' @aliases spark.associationRules,FPGrowthModel-method
+#' @export
+#' @note spark.associationRules(FPGrowthModel) since 2.2.0
+setMethod("spark.associationRules", signature(object = "FPGrowthModel"),
+          function(object) {
+            dataFrame(callJMethod(object@jobj, "associationRules"))
+          })
+
+#  Makes predictions based on generated association rules
+
+#' @param newData a SparkDataFrame for testing.
+#' @return \code{predict} returns a SparkDataFrame containing predicted values.
+#' @rdname spark.fpGrowth
+#' @aliases predict,FPGrowthModel-method
+#' @export
+#' @note predict(FPGrowthModel) since 2.2.0
+setMethod("predict", signature(object = "FPGrowthModel"),
+          function(object, newData) {
+            predict_internal(object, newData)
+          })
+
+#  Saves the FPGrowth model to the output path.
+
+#' @param path the directory where the model is saved.
+#' @param overwrite logical value indicating whether to overwrite if the output path
+#'                  already exists. Default is FALSE which means throw an exception
+#'                  if the output path exists.
+#' @rdname spark.fpGrowth
+#' @aliases write.ml,FPGrowthModel,character-method
+#' @export
+#' @seealso \link{read.ml}
+#' @note write.ml(FPGrowthModel, character) since 2.2.0
+setMethod("write.ml", signature(object = "FPGrowthModel", path = "character"),
+          function(object, path, overwrite = FALSE) {
+            write_internal(object, path, overwrite)
+          })

http://git-wip-us.apache.org/repos/asf/spark/blob/b34f7665/R/pkg/R/mllib_utils.R
----------------------------------------------------------------------
diff --git a/R/pkg/R/mllib_utils.R b/R/pkg/R/mllib_utils.R
index 04a0a6f..5dfef86 100644
--- a/R/pkg/R/mllib_utils.R
+++ b/R/pkg/R/mllib_utils.R
@@ -118,6 +118,8 @@ read.ml <- function(path) {
     new("BisectingKMeansModel", jobj = jobj)
   } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.LinearSVCWrapper")) {
     new("LinearSVCModel", jobj = jobj)
+  } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.FPGrowthWrapper")) {
+    new("FPGrowthModel", jobj = jobj)
   } else {
     stop("Unsupported model: ", jobj)
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/b34f7665/R/pkg/inst/tests/testthat/test_mllib_fpm.R
----------------------------------------------------------------------
diff --git a/R/pkg/inst/tests/testthat/test_mllib_fpm.R b/R/pkg/inst/tests/testthat/test_mllib_fpm.R
new file mode 100644
index 0000000..c38f113
--- /dev/null
+++ b/R/pkg/inst/tests/testthat/test_mllib_fpm.R
@@ -0,0 +1,83 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+library(testthat)
+
+context("MLlib frequent pattern mining")
+
+# Tests for MLlib frequent pattern mining algorithms in SparkR
+sparkSession <- sparkR.session(enableHiveSupport = FALSE)
+
+test_that("spark.fpGrowth", {
+  data <- selectExpr(createDataFrame(data.frame(items = c(
+    "1,2",
+    "1,2",
+    "1,2,3",
+    "1,3"
+  ))), "split(items, ',') as items")
+
+  model <- spark.fpGrowth(data, minSupport = 0.3, minConfidence = 0.8, numPartitions = 1)
+
+  itemsets <- collect(spark.freqItemsets(model))
+
+  expected_itemsets <- data.frame(
+    items = I(list(list("3"), list("3", "1"), list("2"), list("2", "1"), 
list("1"))),
+    freq = c(2, 2, 3, 3, 4)
+  )
+
+  expect_equivalent(expected_itemsets, itemsets)
+
+  expected_association_rules <- data.frame(
+    antecedent = I(list(list("2"), list("3"))),
+    consequent = I(list(list("1"), list("1"))),
+    confidence = c(1, 1)
+  )
+
+  expect_equivalent(expected_association_rules, collect(spark.associationRules(model)))
+
+  new_data <- selectExpr(createDataFrame(data.frame(items = c(
+    "1,2",
+    "1,3",
+    "2,3"
+  ))), "split(items, ',') as items")
+
+  expected_predictions <- data.frame(
+    items = I(list(list("1", "2"), list("1", "3"), list("2", "3"))),
+    prediction = I(list(list(), list(), list("1")))
+  )
+
+  expect_equivalent(expected_predictions, collect(predict(model, new_data)))
+
+  modelPath <- tempfile(pattern = "spark-fpm", fileext = ".tmp")
+  write.ml(model, modelPath, overwrite = TRUE)
+  loaded_model <- read.ml(modelPath)
+
+  expect_equivalent(
+    itemsets,
+    collect(spark.freqItemsets(loaded_model)))
+
+  unlink(modelPath)
+
+  model_without_numpartitions <- spark.fpGrowth(data, minSupport = 0.3, minConfidence = 0.8)
+  expect_equal(
+    count(spark.freqItemsets(model_without_numpartitions)),
+    count(spark.freqItemsets(model))
+  )
+
+})
+
+sparkR.session.stop()

http://git-wip-us.apache.org/repos/asf/spark/blob/b34f7665/mllib/src/main/scala/org/apache/spark/ml/r/FPGrowthWrapper.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/FPGrowthWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/FPGrowthWrapper.scala
new file mode 100644
index 0000000..b8151d8
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/FPGrowthWrapper.scala
@@ -0,0 +1,86 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.r
+
+import org.apache.hadoop.fs.Path
+import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods._
+
+import org.apache.spark.ml.fpm.{FPGrowth, FPGrowthModel}
+import org.apache.spark.ml.util._
+import org.apache.spark.sql.{DataFrame, Dataset}
+
+private[r] class FPGrowthWrapper private (val fpGrowthModel: FPGrowthModel) extends MLWritable {
+  def freqItemsets: DataFrame = fpGrowthModel.freqItemsets
+  def associationRules: DataFrame = fpGrowthModel.associationRules
+
+  def transform(dataset: Dataset[_]): DataFrame = {
+    fpGrowthModel.transform(dataset)
+  }
+
+  override def write: MLWriter = new FPGrowthWrapper.FPGrowthWrapperWriter(this)
+}
+
+private[r] object FPGrowthWrapper extends MLReadable[FPGrowthWrapper] {
+
+  def fit(
+      data: DataFrame,
+      minSupport: Double,
+      minConfidence: Double,
+      itemsCol: String,
+      numPartitions: Integer): FPGrowthWrapper = {
+    val fpGrowth = new FPGrowth()
+      .setMinSupport(minSupport)
+      .setMinConfidence(minConfidence)
+      .setItemsCol(itemsCol)
+
+    if (numPartitions != null && numPartitions > 0) {
+      fpGrowth.setNumPartitions(numPartitions)
+    }
+
+    val fpGrowthModel = fpGrowth.fit(data)
+
+    new FPGrowthWrapper(fpGrowthModel)
+  }
+
+  override def read: MLReader[FPGrowthWrapper] = new FPGrowthWrapperReader
+
+  class FPGrowthWrapperReader extends MLReader[FPGrowthWrapper] {
+    override def load(path: String): FPGrowthWrapper = {
+      val modelPath = new Path(path, "model").toString
+      val fPGrowthModel = FPGrowthModel.load(modelPath)
+
+      new FPGrowthWrapper(fPGrowthModel)
+    }
+  }
+
+  class FPGrowthWrapperWriter(instance: FPGrowthWrapper) extends MLWriter {
+    override protected def saveImpl(path: String): Unit = {
+      val modelPath = new Path(path, "model").toString
+      val rMetadataPath = new Path(path, "rMetadata").toString
+
+      val rMetadataJson: String = compact(render(
+        "class" -> instance.getClass.getName
+      ))
+
+      sc.parallelize(Seq(rMetadataJson), 1).saveAsTextFile(rMetadataPath)
+
+      instance.fpGrowthModel.save(modelPath)
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/b34f7665/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala b/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala
index 358e522..b30ce12 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala
@@ -68,6 +68,8 @@ private[r] object RWrappers extends MLReader[Object] {
         BisectingKMeansWrapper.load(path)
       case "org.apache.spark.ml.r.LinearSVCWrapper" =>
         LinearSVCWrapper.load(path)
+      case "org.apache.spark.ml.r.FPGrowthWrapper" =>
+        FPGrowthWrapper.load(path)
       case _ =>
         throw new SparkException(s"SparkR read.ml does not support load $className")
     }

