Artem Barger created MATH-1378:
----------------------------------

             Summary: KMeansPlusPlusClusterer optimize seeding procedure, by computing sum of squared distances outside the loop.
                 Key: MATH-1378
                 URL: https://issues.apache.org/jira/browse/MATH-1378
             Project: Commons Math
          Issue Type: Improvement
            Reporter: Artem Barger
            Assignee: Artem Barger

Currently, in the KMeansPlusPlusClusterer class, the function that implements initial cluster seeding, *chooseInitialCenters*, executes the following computation inside the while loop:

{code}
while (resultSet.size() < k) {

    // Sum up the squared distances for the points in pointList not
    // already taken.
    double distSqSum = 0.0;

    for (int i = 0; i < numPoints; i++) {
        if (!taken[i]) {
            distSqSum += minDistSquared[i];
        }
    }

    // Rest skipped for simplicity
{code}

However, this sum could be computed once outside the loop and later adjusted as the minimum distances to the set of centers change. E.g.:

{code}
// Sum up the squared distances for the points in pointList not
// already taken.
double distSqSum = 0.0;

// There is no need to compute the sum of squared distances within the "while" loop;
// we can compute the initial value once and maintain deltas in the loop.
for (int i = 0; i < numPoints; i++) {
    if (!taken[i]) {
        distSqSum += minDistSquared[i];
    }
}

while (resultSet.size() < k) {

    // Add one new data point as a center. Each point x is chosen with
    // probability proportional to D(x)^2
    final double r = random.nextDouble() * distSqSum;

    // The index of the next point to be added to the resultSet.
    int nextPointIndex = -1;

    // Sum through the squared min distances again, stopping when
    // sum >= r.
    double sum = 0.0;
    for (int i = 0; i < numPoints; i++) {
        if (!taken[i]) {
            sum += minDistSquared[i];
            if (sum >= r) {
                nextPointIndex = i;
                break;
            }
        }
    }

    // If it's not set to >= 0, the point wasn't found in the previous
    // for loop, probably because distances are extremely small. Just pick
    // the last available point.
    if (nextPointIndex == -1) {
        for (int i = numPoints - 1; i >= 0; i--) {
            if (!taken[i]) {
                nextPointIndex = i;
                break;
            }
        }
    }

    // We found one.
    if (nextPointIndex >= 0) {

        final T p = pointList.get(nextPointIndex);

        resultSet.add(new CentroidCluster<T>(p));

        // Mark it as taken.
        taken[nextPointIndex] = true;

        // A taken point no longer contributes to the sum, so remove its
        // term; the update loop below only visits points not yet taken.
        distSqSum -= minDistSquared[nextPointIndex];

        if (resultSet.size() < k) {
            // Now update elements of minDistSquared. We only have to compute
            // the distance to the new center to do this.
            for (int j = 0; j < numPoints; j++) {
                // Only have to worry about the points still not taken.
                if (!taken[j]) {
                    double d = distance(p, pointList.get(j));
                    // Subtracting the old value.
                    distSqSum -= minDistSquared[j];
                    // Update minimum distance.
                    minDistSquared[j] = FastMath.min(d * d, minDistSquared[j]);
                    // Adjust the overall sum of squared distances.
                    distSqSum += minDistSquared[j];
                }
            }
        }

    } else {
        // None found --
        // Break from the while loop to prevent
        // an infinite loop.
        break;
    }
}
{code}

--
This message was sent by Atlassian JIRA
(v6.3.4#6332)
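
The invariant behind this proposal can be checked in isolation: after each center is picked, the running sum should equal what the original inner loop would have recomputed from scratch. The following is only a minimal sketch under stated assumptions. The class `SeedingSumSketch` and its methods are hypothetical names and not part of Commons Math, points are plain `double[]` arrays, `Math.min` stands in for `FastMath.min`, and the centers are picked by hand instead of at random. Note that the sketch also subtracts the chosen point's own term when it is marked as taken, since that point drops out of the "not taken" set.

```java
final class SeedingSumSketch {

    /** Reference value: the sum the original code recomputes on every pass. */
    static double recomputeSum(double[] minDistSquared, boolean[] taken) {
        double sum = 0.0;
        for (int i = 0; i < minDistSquared.length; i++) {
            if (!taken[i]) {
                sum += minDistSquared[i];
            }
        }
        return sum;
    }

    /** Squared Euclidean distance between two points of equal dimension. */
    static double distSq(double[] a, double[] b) {
        double s = 0.0;
        for (int i = 0; i < a.length; i++) {
            final double diff = a[i] - b[i];
            s += diff * diff;
        }
        return s;
    }

    /**
     * Marks points[chosen] as a center, updates minDistSquared in place
     * and returns the adjusted running sum (hypothetical helper).
     */
    static double takeAndUpdate(double[][] points, double[] minDistSquared,
                                boolean[] taken, int chosen, double distSqSum) {
        taken[chosen] = true;
        // The chosen point leaves the "not taken" set, so drop its term.
        distSqSum -= minDistSquared[chosen];
        for (int j = 0; j < points.length; j++) {
            if (!taken[j]) {
                final double updated =
                    Math.min(distSq(points[chosen], points[j]), minDistSquared[j]);
                // Apply the delta between the new and the old minimum.
                distSqSum += updated - minDistSquared[j];
                minDistSquared[j] = updated;
            }
        }
        return distSqSum;
    }

    public static void main(String[] args) {
        final double[][] points = {{0, 0}, {1, 0}, {4, 0}, {4, 3}, {10, 10}};
        final boolean[] taken = new boolean[points.length];
        final double[] minDistSquared = new double[points.length];

        // Assumption for the demo: index 0 is the first center.
        taken[0] = true;
        for (int i = 1; i < points.length; i++) {
            minDistSquared[i] = distSq(points[0], points[i]);
        }
        double running = recomputeSum(minDistSquared, taken);

        // Hand-picked centers, so that the run is deterministic.
        for (int chosen : new int[] {4, 2}) {
            running = takeAndUpdate(points, minDistSquared, taken, chosen, running);
            System.out.println("running=" + running
                + " recomputed=" + recomputeSum(minDistSquared, taken));
        }
    }
}
```

With real-valued data the incrementally maintained sum may drift slightly from the recomputed one because of floating-point rounding, so a comparison on arbitrary input should probably use a small tolerance rather than exact equality.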