Repository: spark Updated Branches: refs/heads/branch-1.3 c5f3b9e02 -> ba941ceb1
[SPARK-5900][MLLIB] make PIC and FPGrowth Java-friendly In the previous version, PIC stored clustering assignments as an `RDD[(Long, Int)]`. This is mapped to `RDD<Tuple2<Object, Object>>` in Java and hence Java users have to cast types manually. We should either create a new method called `javaAssignments` that returns `JavaRDD[(java.lang.Long, java.lang.Integer)]` or wrap the result pair in a class. I chose the latter approach in this PR. Now assignments are stored as an `RDD[Assignment]`, where `Assignment` is a class with `id` and `cluster`. Similarly, in FPGrowth, the frequent itemsets were stored as an `RDD[(Array[Item], Long)]`, which is mapped to `RDD<Tuple2<Object, Object>>`. Though we provide a "Java-friendly" method `javaFreqItemsets` that returns `JavaPairRDD[Array[Item], java.lang.Long]`, it doesn't really work because `Array[Item]` is mapped to `Object` in Java. So in this PR I created a class `FreqItemset` to wrap the results. It has `items` and `freq`, as well as a `javaItems` method that returns `List<Item>` in Java. I'm not certain that the names I chose are proper: `Assignment`/`id`/`cluster` and `FreqItemset`/`items`/`freq`. Please let me know if there are better suggestions. 
CC: jkbradley Author: Xiangrui Meng <m...@databricks.com> Closes #4695 from mengxr/SPARK-5900 and squashes the following commits: 865b5ca [Xiangrui Meng] make Assignment serializable cffa96e [Xiangrui Meng] fix test 9c0e590 [Xiangrui Meng] remove unused Tuple2 1b9db3d [Xiangrui Meng] make PIC and FPGrowth Java-friendly (cherry picked from commit 0cfd2cebde0b7fac3779eda80d6e42223f8a3d9f) Signed-off-by: Xiangrui Meng <m...@databricks.com> Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/ba941ceb Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/ba941ceb Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/ba941ceb Branch: refs/heads/branch-1.3 Commit: ba941ceb1f78b28ca5cfb18c770f4171b9c74b0a Parents: c5f3b9e Author: Xiangrui Meng <m...@databricks.com> Authored: Thu Feb 19 18:06:16 2015 -0800 Committer: Xiangrui Meng <m...@databricks.com> Committed: Thu Feb 19 18:06:26 2015 -0800 ---------------------------------------------------------------------- docs/mllib-clustering.md | 8 ++-- docs/mllib-frequent-pattern-mining.md | 12 +++--- .../examples/mllib/JavaFPGrowthExample.java | 8 ++-- .../JavaPowerIterationClusteringExample.java | 5 +-- .../spark/examples/mllib/FPGrowthExample.scala | 4 +- .../mllib/PowerIterationClusteringExample.scala | 8 +--- .../clustering/PowerIterationClustering.scala | 33 +++++++++++++--- .../org/apache/spark/mllib/fpm/FPGrowth.scala | 41 ++++++++++++++------ .../spark/mllib/fpm/JavaFPGrowthSuite.java | 30 +++++--------- .../PowerIterationClusteringSuite.scala | 8 ++-- .../apache/spark/mllib/fpm/FPGrowthSuite.scala | 10 ++--- 11 files changed, 93 insertions(+), 74 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/ba941ceb/docs/mllib-clustering.md ---------------------------------------------------------------------- diff --git a/docs/mllib-clustering.md b/docs/mllib-clustering.md index 
6e46a47..0b6db4f 100644 --- a/docs/mllib-clustering.md +++ b/docs/mllib-clustering.md @@ -314,8 +314,8 @@ val pic = new PowerIteartionClustering() .setMaxIterations(20) val model = pic.run(similarities) -model.assignments.foreach { case (vertexId, clusterId) => - println(s"$vertexId -> $clusterId") +model.assignments.foreach { a => + println(s"${a.id} -> ${a.cluster}") } {% endhighlight %} @@ -349,8 +349,8 @@ PowerIterationClustering pic = new PowerIterationClustering() .setMaxIterations(10); PowerIterationClusteringModel model = pic.run(similarities); -for (Tuple2<Object, Object> assignment: model.assignments().toJavaRDD().collect()) { - System.out.println(assignment._1() + " -> " + assignment._2()); +for (PowerIterationClustering.Assignment a: model.assignments().toJavaRDD().collect()) { + System.out.println(a.id() + " -> " + a.cluster()); } {% endhighlight %} </div> http://git-wip-us.apache.org/repos/asf/spark/blob/ba941ceb/docs/mllib-frequent-pattern-mining.md ---------------------------------------------------------------------- diff --git a/docs/mllib-frequent-pattern-mining.md b/docs/mllib-frequent-pattern-mining.md index 0ff9738..9fd9be0 100644 --- a/docs/mllib-frequent-pattern-mining.md +++ b/docs/mllib-frequent-pattern-mining.md @@ -57,8 +57,8 @@ val fpg = new FPGrowth() .setNumPartitions(10) val model = fpg.run(transactions) -model.freqItemsets.collect().foreach { case (itemset, freq) => - println(itemset.mkString("[", ",", "]") + ", " + freq) +model.freqItemsets.collect().foreach { itemset => + println(itemset.items.mkString("[", ",", "]") + ", " + itemset.freq) } {% endhighlight %} @@ -74,10 +74,9 @@ Calling `FPGrowth.run` with transactions returns an that stores the frequent itemsets with their frequencies. 
{% highlight java %} -import java.util.Arrays; import java.util.List; -import scala.Tuple2; +import com.google.common.base.Joiner; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.mllib.fpm.FPGrowth; @@ -88,11 +87,10 @@ JavaRDD<List<String>> transactions = ... FPGrowth fpg = new FPGrowth() .setMinSupport(0.2) .setNumPartitions(10); - FPGrowthModel<String> model = fpg.run(transactions); -for (Tuple2<Object, Long> s: model.javaFreqItemsets().collect()) { - System.out.println("(" + Arrays.toString((Object[]) s._1()) + "): " + s._2()); +for (FPGrowth.FreqItemset<String> itemset: model.freqItemsets().toJavaRDD().collect()) { + System.out.println("[" + Joiner.on(",").join(s.javaItems()) + "], " + s.freq()); } {% endhighlight %} http://git-wip-us.apache.org/repos/asf/spark/blob/ba941ceb/examples/src/main/java/org/apache/spark/examples/mllib/JavaFPGrowthExample.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaFPGrowthExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaFPGrowthExample.java index 0db572d..f50e802 100644 --- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaFPGrowthExample.java +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaFPGrowthExample.java @@ -18,10 +18,8 @@ package org.apache.spark.examples.mllib; import java.util.ArrayList; -import java.util.Arrays; - -import scala.Tuple2; +import com.google.common.base.Joiner; import com.google.common.collect.Lists; import org.apache.spark.SparkConf; @@ -54,8 +52,8 @@ public class JavaFPGrowthExample { .setMinSupport(0.3); FPGrowthModel<String> model = fpg.run(transactions); - for (Tuple2<Object, Long> s: model.javaFreqItemsets().collect()) { - System.out.println(Arrays.toString((Object[]) s._1()) + ", " + s._2()); + for (FPGrowth.FreqItemset<String> s: model.freqItemsets().toJavaRDD().collect()) { + System.out.println("[" + 
Joiner.on(",").join(s.javaItems()) + "], " + s.freq()); } sc.stop(); http://git-wip-us.apache.org/repos/asf/spark/blob/ba941ceb/examples/src/main/java/org/apache/spark/examples/mllib/JavaPowerIterationClusteringExample.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaPowerIterationClusteringExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaPowerIterationClusteringExample.java index e9371de..6c6f976 100644 --- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaPowerIterationClusteringExample.java +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaPowerIterationClusteringExample.java @@ -17,7 +17,6 @@ package org.apache.spark.examples.mllib; -import scala.Tuple2; import scala.Tuple3; import com.google.common.collect.Lists; @@ -49,8 +48,8 @@ public class JavaPowerIterationClusteringExample { .setMaxIterations(10); PowerIterationClusteringModel model = pic.run(similarities); - for (Tuple2<Object, Object> assignment: model.assignments().toJavaRDD().collect()) { - System.out.println(assignment._1() + " -> " + assignment._2()); + for (PowerIterationClustering.Assignment a: model.assignments().toJavaRDD().collect()) { + System.out.println(a.id() + " -> " + a.cluster()); } sc.stop(); http://git-wip-us.apache.org/repos/asf/spark/blob/ba941ceb/examples/src/main/scala/org/apache/spark/examples/mllib/FPGrowthExample.scala ---------------------------------------------------------------------- diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/FPGrowthExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/FPGrowthExample.scala index ae66107..aaae275 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/FPGrowthExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/FPGrowthExample.scala @@ -42,8 +42,8 @@ object FPGrowthExample { .setMinSupport(0.3) val model = 
fpg.run(transactions) - model.freqItemsets.collect().foreach { case (itemset, freq) => - println(itemset.mkString("[", ",", "]") + ", " + freq) + model.freqItemsets.collect().foreach { itemset => + println(itemset.items.mkString("[", ",", "]") + ", " + itemset.freq) } sc.stop() http://git-wip-us.apache.org/repos/asf/spark/blob/ba941ceb/examples/src/main/scala/org/apache/spark/examples/mllib/PowerIterationClusteringExample.scala ---------------------------------------------------------------------- diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/PowerIterationClusteringExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/PowerIterationClusteringExample.scala index b2373ad..91c9772 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/PowerIterationClusteringExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/PowerIterationClusteringExample.scala @@ -44,8 +44,7 @@ import org.apache.spark.{SparkConf, SparkContext} * * Here is a sample run and output: * - * ./bin/run-example mllib.PowerIterationClusteringExample - * -k 3 --n 30 --maxIterations 15 + * ./bin/run-example mllib.PowerIterationClusteringExample -k 3 --n 30 --maxIterations 15 * * Cluster assignments: 1 -> [0,1,2,3,4],2 -> [5,6,7,8,9,10,11,12,13,14], * 0 -> [15,16,17,18,19,20,21,22,23,24,25,26,27,28,29] @@ -103,7 +102,7 @@ object PowerIterationClusteringExample { .setMaxIterations(params.maxIterations) .run(circlesRdd) - val clusters = model.assignments.collect.groupBy(_._2).mapValues(_.map(_._1)) + val clusters = model.assignments.collect().groupBy(_.cluster).mapValues(_.map(_.id)) val assignments = clusters.toList.sortBy { case (k, v) => v.length} val assignmentsStr = assignments .map { case (k, v) => @@ -153,8 +152,5 @@ object PowerIterationClusteringExample { val expCoeff = -1.0 / 2.0 * math.pow(sigma, 2.0) val ssquares = (p1._1 - p2._1) * (p1._1 - p2._1) + (p1._2 - p2._2) * (p1._2 - p2._2) coeff * math.exp(expCoeff * ssquares) 
- // math.exp((p1._1 - p2._1) * (p1._1 - p2._1) + (p1._2 - p2._2) * (p1._2 - p2._2)) } - - } http://git-wip-us.apache.org/repos/asf/spark/blob/ba941ceb/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala index 63d0334..1800239 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala @@ -17,9 +17,9 @@ package org.apache.spark.mllib.clustering -import org.apache.spark.api.java.JavaRDD import org.apache.spark.{Logging, SparkException} import org.apache.spark.annotation.Experimental +import org.apache.spark.api.java.JavaRDD import org.apache.spark.graphx._ import org.apache.spark.graphx.impl.GraphImpl import org.apache.spark.mllib.linalg.Vectors @@ -33,12 +33,12 @@ import org.apache.spark.util.random.XORShiftRandom * Model produced by [[PowerIterationClustering]]. 
* * @param k number of clusters - * @param assignments an RDD of (vertexID, clusterID) pairs + * @param assignments an RDD of clustering [[PowerIterationClustering#Assignment]]s */ @Experimental class PowerIterationClusteringModel( val k: Int, - val assignments: RDD[(Long, Int)]) extends Serializable + val assignments: RDD[PowerIterationClustering.Assignment]) extends Serializable /** * :: Experimental :: @@ -133,16 +133,33 @@ class PowerIterationClustering private[clustering] ( */ private def pic(w: Graph[Double, Double]): PowerIterationClusteringModel = { val v = powerIter(w, maxIterations) - val assignments = kMeans(v, k) + val assignments = kMeans(v, k).mapPartitions({ iter => + iter.map { case (id, cluster) => + new Assignment(id, cluster) + } + }, preservesPartitioning = true) new PowerIterationClusteringModel(k, assignments) } } -private[clustering] object PowerIterationClustering extends Logging { +@Experimental +object PowerIterationClustering extends Logging { + + /** + * :: Experimental :: + * Cluster assignment. + * @param id node id + * @param cluster assigned cluster id + */ + @Experimental + class Assignment(val id: Long, val cluster: Int) extends Serializable + /** * Normalizes the affinity matrix (A) by row sums and returns the normalized affinity matrix (W). 
*/ - def normalize(similarities: RDD[(Long, Long, Double)]): Graph[Double, Double] = { + private[clustering] + def normalize(similarities: RDD[(Long, Long, Double)]) + : Graph[Double, Double] = { val edges = similarities.flatMap { case (i, j, s) => if (s < 0.0) { throw new SparkException("Similarity must be nonnegative but found s($i, $j) = $s.") @@ -173,6 +190,7 @@ private[clustering] object PowerIterationClustering extends Logging { * @return a graph with edges representing W and vertices representing a random vector * with unit 1-norm */ + private[clustering] def randomInit(g: Graph[Double, Double]): Graph[Double, Double] = { val r = g.vertices.mapPartitionsWithIndex( (part, iter) => { @@ -194,6 +212,7 @@ private[clustering] object PowerIterationClustering extends Logging { * @param g a graph representing the normalized affinity matrix (W) * @return a graph with edges representing W and vertices representing the degree vector */ + private[clustering] def initDegreeVector(g: Graph[Double, Double]): Graph[Double, Double] = { val sum = g.vertices.values.sum() val v0 = g.vertices.mapValues(_ / sum) @@ -207,6 +226,7 @@ private[clustering] object PowerIterationClustering extends Logging { * @param maxIterations maximum number of iterations * @return a [[VertexRDD]] representing the pseudo-eigenvector */ + private[clustering] def powerIter( g: Graph[Double, Double], maxIterations: Int): VertexRDD[Double] = { @@ -246,6 +266,7 @@ private[clustering] object PowerIterationClustering extends Logging { * @param k number of clusters * @return a [[VertexRDD]] representing the clustering assignments */ + private[clustering] def kMeans(v: VertexRDD[Double], k: Int): VertexRDD[Int] = { val points = v.mapValues(x => Vectors.dense(x)).cache() val model = new KMeans() http://git-wip-us.apache.org/repos/asf/spark/blob/ba941ceb/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala ---------------------------------------------------------------------- diff --git 
a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala index 3168d60..efa8459 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala @@ -26,8 +26,9 @@ import scala.reflect.ClassTag import org.apache.spark.{HashPartitioner, Logging, Partitioner, SparkException} import org.apache.spark.annotation.Experimental -import org.apache.spark.api.java.{JavaPairRDD, JavaRDD} +import org.apache.spark.api.java.JavaRDD import org.apache.spark.api.java.JavaSparkContext.fakeClassTag +import org.apache.spark.mllib.fpm.FPGrowth.FreqItemset import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel @@ -35,18 +36,11 @@ import org.apache.spark.storage.StorageLevel * :: Experimental :: * * Model trained by [[FPGrowth]], which holds frequent itemsets. - * @param freqItemsets frequent itemset, which is an RDD of (itemset, frequency) pairs + * @param freqItemsets frequent itemset, which is an RDD of [[FreqItemset]] * @tparam Item item type */ @Experimental -class FPGrowthModel[Item: ClassTag]( - val freqItemsets: RDD[(Array[Item], Long)]) extends Serializable { - - /** Returns frequent itemsets as a [[org.apache.spark.api.java.JavaPairRDD]]. 
*/ - def javaFreqItemsets(): JavaPairRDD[Array[Item], java.lang.Long] = { - JavaPairRDD.fromRDD(freqItemsets).asInstanceOf[JavaPairRDD[Array[Item], java.lang.Long]] - } -} +class FPGrowthModel[Item: ClassTag](val freqItemsets: RDD[FreqItemset[Item]]) extends Serializable /** * :: Experimental :: @@ -151,7 +145,7 @@ class FPGrowth private ( data: RDD[Array[Item]], minCount: Long, freqItems: Array[Item], - partitioner: Partitioner): RDD[(Array[Item], Long)] = { + partitioner: Partitioner): RDD[FreqItemset[Item]] = { val itemToRank = freqItems.zipWithIndex.toMap data.flatMap { transaction => genCondTransactions(transaction, itemToRank, partitioner) @@ -161,7 +155,7 @@ class FPGrowth private ( .flatMap { case (part, tree) => tree.extract(minCount, x => partitioner.getPartition(x) == part) }.map { case (ranks, count) => - (ranks.map(i => freqItems(i)).toArray, count) + new FreqItemset(ranks.map(i => freqItems(i)).toArray, count) } } @@ -193,3 +187,26 @@ class FPGrowth private ( output } } + +/** + * :: Experimental :: + */ +@Experimental +object FPGrowth { + + /** + * Frequent itemset. + * @param items items in this itemset. Java users should call [[FreqItemset#javaItems]] instead. + * @param freq frequency + * @tparam Item item type + */ + class FreqItemset[Item](val items: Array[Item], val freq: Long) extends Serializable { + + /** + * Returns items in a Java List. 
+ */ + def javaItems: java.util.List[Item] = { + items.toList.asJava + } + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/ba941ceb/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java ---------------------------------------------------------------------- diff --git a/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java index 851707c..bd0edf2 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java @@ -19,6 +19,7 @@ package org.apache.spark.mllib.fpm; import java.io.Serializable; import java.util.ArrayList; +import java.util.List; import org.junit.After; import org.junit.Before; @@ -28,6 +29,7 @@ import static org.junit.Assert.*; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.mllib.fpm.FPGrowth.FreqItemset; public class JavaFPGrowthSuite implements Serializable { private transient JavaSparkContext sc; @@ -55,30 +57,18 @@ public class JavaFPGrowthSuite implements Serializable { Lists.newArrayList("z".split(" ")), Lists.newArrayList("x z y r q t p".split(" "))), 2); - FPGrowth fpg = new FPGrowth(); - - FPGrowthModel<String> model6 = fpg - .setMinSupport(0.9) - .setNumPartitions(1) - .run(rdd); - assertEquals(0, model6.javaFreqItemsets().count()); - - FPGrowthModel<String> model3 = fpg + FPGrowthModel<String> model = new FPGrowth() .setMinSupport(0.5) .setNumPartitions(2) .run(rdd); - assertEquals(18, model3.javaFreqItemsets().count()); - FPGrowthModel<String> model2 = fpg - .setMinSupport(0.3) - .setNumPartitions(4) - .run(rdd); - assertEquals(54, model2.javaFreqItemsets().count()); + List<FreqItemset<String>> freqItemsets = model.freqItemsets().toJavaRDD().collect(); + assertEquals(18, freqItemsets.size()); - FPGrowthModel<String> model1 = fpg - .setMinSupport(0.1) - 
.setNumPartitions(8) - .run(rdd); - assertEquals(625, model1.javaFreqItemsets().count()); + for (FreqItemset<String> itemset: freqItemsets) { + // Test return types. + List<String> items = itemset.javaItems(); + long freq = itemset.freq(); + } } } http://git-wip-us.apache.org/repos/asf/spark/blob/ba941ceb/mllib/src/test/scala/org/apache/spark/mllib/clustering/PowerIterationClusteringSuite.scala ---------------------------------------------------------------------- diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/PowerIterationClusteringSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/PowerIterationClusteringSuite.scala index 03ecd9c..6315c03 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/PowerIterationClusteringSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/PowerIterationClusteringSuite.scala @@ -51,8 +51,8 @@ class PowerIterationClusteringSuite extends FunSuite with MLlibTestSparkContext .setK(2) .run(sc.parallelize(similarities, 2)) val predictions = Array.fill(2)(mutable.Set.empty[Long]) - model.assignments.collect().foreach { case (i, c) => - predictions(c) += i + model.assignments.collect().foreach { a => + predictions(a.cluster) += a.id } assert(predictions.toSet == Set((0 to 3).toSet, (4 to 15).toSet)) @@ -61,8 +61,8 @@ class PowerIterationClusteringSuite extends FunSuite with MLlibTestSparkContext .setInitializationMode("degree") .run(sc.parallelize(similarities, 2)) val predictions2 = Array.fill(2)(mutable.Set.empty[Long]) - model2.assignments.collect().foreach { case (i, c) => - predictions2(c) += i + model2.assignments.collect().foreach { a => + predictions2(a.cluster) += a.id } assert(predictions2.toSet == Set((0 to 3).toSet, (4 to 15).toSet)) } http://git-wip-us.apache.org/repos/asf/spark/blob/ba941ceb/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala ---------------------------------------------------------------------- diff --git 
a/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala index 6812828..bd5b9cc 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala @@ -46,8 +46,8 @@ class FPGrowthSuite extends FunSuite with MLlibTestSparkContext { .setMinSupport(0.5) .setNumPartitions(2) .run(rdd) - val freqItemsets3 = model3.freqItemsets.collect().map { case (items, count) => - (items.toSet, count) + val freqItemsets3 = model3.freqItemsets.collect().map { itemset => + (itemset.items.toSet, itemset.freq) } val expected = Set( (Set("s"), 3L), (Set("z"), 5L), (Set("x"), 4L), (Set("t"), 3L), (Set("y"), 3L), @@ -96,10 +96,10 @@ class FPGrowthSuite extends FunSuite with MLlibTestSparkContext { .setMinSupport(0.5) .setNumPartitions(2) .run(rdd) - assert(model3.freqItemsets.first()._1.getClass === Array(1).getClass, + assert(model3.freqItemsets.first().items.getClass === Array(1).getClass, "frequent itemsets should use primitive arrays") - val freqItemsets3 = model3.freqItemsets.collect().map { case (items, count) => - (items.toSet, count) + val freqItemsets3 = model3.freqItemsets.collect().map { itemset => + (itemset.items.toSet, itemset.freq) } val expected = Set( (Set(1), 6L), (Set(2), 5L), (Set(3), 5L), (Set(4), 4L), --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org