Author: luc Date: Sun Jun 5 16:27:53 2011 New Revision: 1132448 URL: http://svn.apache.org/viewvc?rev=1132448&view=rev Log: Improved k-means++ clustering performance and initial cluster center choice.
JIRA: MATH-584 Modified: commons/proper/math/trunk/src/main/java/org/apache/commons/math/stat/clustering/KMeansPlusPlusClusterer.java commons/proper/math/trunk/src/site/xdoc/changes.xml Modified: commons/proper/math/trunk/src/main/java/org/apache/commons/math/stat/clustering/KMeansPlusPlusClusterer.java URL: http://svn.apache.org/viewvc/commons/proper/math/trunk/src/main/java/org/apache/commons/math/stat/clustering/KMeansPlusPlusClusterer.java?rev=1132448&r1=1132447&r2=1132448&view=diff ============================================================================== --- commons/proper/math/trunk/src/main/java/org/apache/commons/math/stat/clustering/KMeansPlusPlusClusterer.java (original) +++ commons/proper/math/trunk/src/main/java/org/apache/commons/math/stat/clustering/KMeansPlusPlusClusterer.java Sun Jun 5 16:27:53 2011 @@ -19,6 +19,7 @@ package org.apache.commons.math.stat.clu import java.util.ArrayList; import java.util.Collection; +import java.util.Collections; import java.util.List; import java.util.Random; @@ -193,41 +194,121 @@ public class KMeansPlusPlusClusterer<T e private static <T extends Clusterable<T>> List<Cluster<T>> chooseInitialCenters(final Collection<T> points, final int k, final Random random) { - final List<T> pointSet = new ArrayList<T>(points); + // Convert to list for indexed access. Make it unmodifiable, since removal of items + // would screw up the logic of this method. + final List<T> pointList = Collections.unmodifiableList(new ArrayList<T> (points)); + + // The number of points in the list. + final int numPoints = pointList.size(); + + // Set the corresponding element in this array to indicate when + // elements of pointList are no longer available. + final boolean[] taken = new boolean[numPoints]; + + // The resulting list of initial centers. final List<Cluster<T>> resultSet = new ArrayList<Cluster<T>>(); // Choose one center uniformly at random from among the data points. 
- final T firstPoint = pointSet.remove(random.nextInt(pointSet.size())); + final int firstPointIndex = random.nextInt(numPoints); + + final T firstPoint = pointList.get(firstPointIndex); + resultSet.add(new Cluster<T>(firstPoint)); - final double[] dx2 = new double[pointSet.size()]; + // Must mark it as taken + taken[firstPointIndex] = true; + + // To keep track of the minimum distance squared of elements of + // pointList to elements of resultSet. + final double[] minDistSquared = new double[numPoints]; + + // Initialize the elements. Since the only point in resultSet is firstPoint, + // this is very easy. + for (int i = 0; i < numPoints; i++) { + if (i != firstPointIndex) { // That point isn't considered + double d = firstPoint.distanceFrom(pointList.get(i)); + minDistSquared[i] = d*d; + } + } + while (resultSet.size() < k) { - // For each data point x, compute D(x), the distance between x and - // the nearest center that has already been chosen. - double sum = 0; - for (int i = 0; i < pointSet.size(); i++) { - final T p = pointSet.get(i); - int nearestClusterIndex = getNearestCluster(resultSet, p); - final Cluster<T> nearest = resultSet.get(nearestClusterIndex); - final double d = p.distanceFrom(nearest.getCenter()); - sum += d * d; - dx2[i] = sum; + + // Sum up the squared distances for the points in pointList not + // already taken. + double distSqSum = 0.0; + + for (int i = 0; i < numPoints; i++) { + if (!taken[i]) { + distSqSum += minDistSquared[i]; + } } // Add one new data point as a center. Each point x is chosen with // probability proportional to D(x)2 - final double r = random.nextDouble() * sum; - for (int i = 0 ; i < dx2.length; i++) { - if (dx2[i] >= r) { - final T p = pointSet.remove(i); - resultSet.add(new Cluster<T>(p)); - break; + final double r = random.nextDouble() * distSqSum; + + // The index of the next point to be added to the resultSet. 
+ int nextPointIndex = -1; + + // Sum through the squared min distances again, stopping when + // sum >= r. + double sum = 0.0; + for (int i = 0; i < numPoints; i++) { + if (!taken[i]) { + sum += minDistSquared[i]; + if (sum >= r) { + nextPointIndex = i; + break; + } } } + + // If it's not set to >= 0, the point wasn't found in the previous + // for loop, probably because distances are extremely small. Just pick + // the last available point. + if (nextPointIndex == -1) { + for (int i = numPoints - 1; i >= 0; i--) { + if (!taken[i]) { + nextPointIndex = i; + break; + } + } + } + + // We found one. + if (nextPointIndex >= 0) { + + final T p = pointList.get(nextPointIndex); + + resultSet.add(new Cluster<T> (p)); + + // Mark it as taken. + taken[nextPointIndex] = true; + + if (resultSet.size() < k) { + // Now update elements of minDistSquared. We only have to compute + // the distance to the new center to do this. + for (int j = 0; j < numPoints; j++) { + // Only have to worry about the points still not taken. + if (!taken[j]) { + double d = p.distanceFrom(pointList.get(j)); + double d2 = d * d; + if (d2 < minDistSquared[j]) { + minDistSquared[j] = d2; + } + } + } + } + + } else { + // None found -- + // Break from the while loop to prevent + // an infinite loop. + break; + } } return resultSet; - } /** Modified: commons/proper/math/trunk/src/site/xdoc/changes.xml URL: http://svn.apache.org/viewvc/commons/proper/math/trunk/src/site/xdoc/changes.xml?rev=1132448&r1=1132447&r2=1132448&view=diff ============================================================================== --- commons/proper/math/trunk/src/site/xdoc/changes.xml (original) +++ commons/proper/math/trunk/src/site/xdoc/changes.xml Sun Jun 5 16:27:53 2011 @@ -52,6 +52,9 @@ The <action> type attribute can be add,u If the output is not quite correct, check for invisible trailing spaces! 
--> <release version="3.0" date="TBD" description="TBD"> + <action dev="luc" type="fix" issue="MATH-584" due-to="Randall Scarberry"> + Improved k-means++ clustering performances and initial cluster center choice. + </action> <action dev="luc" type="fix" issue="MATH-504" due-to="X. B."> Fixed tricube function implementation in Loess interpolator. </action>