[ https://issues.apache.org/jira/browse/FLINK-2131?page=com.atlassian.jira.plugin.system.issuetabpanels:comment-tabpanel&focusedCommentId=15537304#comment-15537304 ]
ASF GitHub Bot commented on FLINK-2131: --------------------------------------- Github user skonto commented on a diff in the pull request: https://github.com/apache/flink/pull/757#discussion_r81430683 --- Diff: flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/clustering/KMeansITSuite.scala --- @@ -0,0 +1,142 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.clustering + +import org.apache.flink.api.scala._ +import org.apache.flink.ml._ +import org.apache.flink.ml.math +import org.apache.flink.ml.math.DenseVector +import org.apache.flink.test.util.FlinkTestBase +import org.scalatest.{FlatSpec, Matchers} + +class KMeansITSuite extends FlatSpec with Matchers with FlinkTestBase { + + behavior of "The KMeans implementation" + + def fixture = new { + val env = ExecutionEnvironment.getExecutionEnvironment + val kmeans = KMeans(). + setInitialCentroids(ClusteringData.centroidData). 
+ setNumIterations(ClusteringData.iterations) + + val trainingDS = env.fromCollection(ClusteringData.trainingData) + + kmeans.fit(trainingDS) + } + + it should "cluster data points into 'K' cluster centers" in { + val f = fixture + + val centroidsResult = f.kmeans.centroids.get.collect().apply(0) + + val centroidsExpected = ClusteringData.expectedCentroids + + // the sizes must match + centroidsResult.length should be === centroidsExpected.length + + // create a lookup table for better matching + val expectedMap = centroidsExpected map (e => e.label->e.vector.asInstanceOf[DenseVector]) toMap + + // each of the results must be in lookup table + centroidsResult.iterator.foreach(result => { + val expectedVector = expectedMap.get(result.label).get + + // the type must match (not None) + expectedVector shouldBe a [math.DenseVector] + + val expectedData = expectedVector.asInstanceOf[DenseVector].data + val resultData = result.vector.asInstanceOf[DenseVector].data + + // match the individual values of the vector + expectedData zip resultData foreach { + case (expectedVector, entryVector) => + entryVector should be(expectedVector +- 0.00001) + } + }) + } + + it should "predict points to cluster centers" in { + val f = fixture + + val vectorsWithExpectedLabels = ClusteringData.testData + // create a lookup table for better matching + val expectedMap = vectorsWithExpectedLabels map (v => + v.vector.asInstanceOf[DenseVector] -> v.label + ) toMap + + // calculate the vector to cluster mapping on the plain vectors + val plainVectors = vectorsWithExpectedLabels.map(v => v.vector) + val predictedVectors = f.kmeans.predict(f.env.fromCollection(plainVectors)) + + // check if all vectors were labeled correctly + predictedVectors.collect() foreach (result => { + val expectedLabel = expectedMap.get(result._1.asInstanceOf[DenseVector]).get + result._2 should be(expectedLabel) + }) + + } + + it should "initialize k cluster centers randomly" in { + + val env = 
ExecutionEnvironment.getExecutionEnvironment + val kmeans = KMeans() + .setNumClusters(10) + .setNumIterations(ClusteringData.iterations) + .setInitializationStrategy("random") + + val trainingDS = env.fromCollection(ClusteringData.trainingData) + kmeans.fit(trainingDS) + + println(trainingDS.mapWithBcVariable(kmeans.centroids.get) { --- End diff -- assertion? > Add Initialization schemes for K-means clustering > ------------------------------------------------- > > Key: FLINK-2131 > URL: https://issues.apache.org/jira/browse/FLINK-2131 > Project: Flink > Issue Type: Task > Components: Machine Learning Library > Reporter: Sachin Goel > Assignee: Sachin Goel > > The Lloyd's [KMeans] algorithm takes initial centroids as its input. However, > in case the user doesn't provide the initial centers, they may ask for a > particular initialization scheme to be followed. The most commonly used are > these: > 1. Random initialization: Self-explanatory > 2. kmeans++ initialization: http://ilpubs.stanford.edu:8090/778/1/2006-13.pdf > 3. kmeans|| : http://theory.stanford.edu/~sergei/papers/vldb12-kmpar.pdf > For very large data sets, or for large values of k, the kmeans|| method is > preferred as it provides the same approximation guarantees as kmeans++ and > requires fewer passes over the input data. -- This message was sent by Atlassian JIRA (v6.3.4#6332)