http://git-wip-us.apache.org/repos/asf/mahout/blob/99a5358f/integration/src/main/java/org/apache/mahout/classifier/ConfusionMatrixDumper.java ---------------------------------------------------------------------- diff --git a/integration/src/main/java/org/apache/mahout/classifier/ConfusionMatrixDumper.java b/integration/src/main/java/org/apache/mahout/classifier/ConfusionMatrixDumper.java deleted file mode 100644 index 03a3000..0000000 --- a/integration/src/main/java/org/apache/mahout/classifier/ConfusionMatrixDumper.java +++ /dev/null @@ -1,425 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.mahout.classifier; - -import com.google.common.collect.Lists; -import org.apache.commons.io.Charsets; -import org.apache.hadoop.conf.Configuration; -import org.apache.hadoop.fs.FileSystem; -import org.apache.hadoop.fs.Path; -import org.apache.hadoop.io.SequenceFile; -import org.apache.hadoop.io.Text; -import org.apache.hadoop.util.ToolRunner; -import org.apache.mahout.common.AbstractJob; -import org.apache.mahout.common.commandline.DefaultOptionCreator; -import org.apache.mahout.math.Matrix; -import org.apache.mahout.math.MatrixWritable; - -import java.io.File; -import java.io.FileOutputStream; -import java.io.IOException; -import java.io.OutputStream; -import java.io.PrintStream; -import java.util.Iterator; -import java.util.List; -import java.util.Map; - -/** - * Export a ConfusionMatrix in various text formats: ToString version Grayscale HTML table Summary HTML table - * Table of counts all with optional HTML wrappers - * - * Input format: Hadoop SequenceFile with Text key and MatrixWritable value, 1 pair - * - * Intended to consume ConfusionMatrix SequenceFile output by Bayes TestClassifier class - */ -public final class ConfusionMatrixDumper extends AbstractJob { - - private static final String TAB_SEPARATOR = "|"; - - // HTML wrapper - default CSS - private static final String HEADER = "<html>" - + "<head>\n" - + "<title>TITLE</title>\n" - + "</head>" - + "<body>\n" - + "<style type='text/css'> \n" - + "table\n" - + "{\n" - + "border:3px solid black; text-align:left;\n" - + "}\n" - + "th.normalHeader\n" - + "{\n" - + "border:1px solid black;border-collapse:collapse;text-align:center;" - + "background-color:white\n" - + "}\n" - + "th.tallHeader\n" - + "{\n" - + "border:1px solid black;border-collapse:collapse;text-align:center;" - + "background-color:white; height:6em\n" - + "}\n" - + "tr.label\n" - + "{\n" - + "border:1px solid black;border-collapse:collapse;text-align:center;" - + "background-color:white\n" - + "}\n" - + "tr.row\n" - + "{\n" - + "border:1px solid gray;text-align:center;background-color:snow\n" - + "}\n" - + "td\n" - + "{\n" - + "min-width:2em\n" - + "}\n" - + "td.cell\n" - + "{\n" - + "border:1px 
solid black;text-align:right;background-color:snow\n" - + "}\n" - + "td.empty\n" - + "{\n" - + "border:0px;text-align:right;background-color:snow\n" - + "}\n" - + "td.white\n" - + "{\n" - + "border:0px solid black;text-align:right;background-color:white\n" - + "}\n" - + "td.black\n" - + "{\n" - + "border:0px solid red;text-align:right;background-color:black\n" - + "}\n" - + "td.gray1\n" - + "{\n" - + "border:0px solid green;text-align:right; background-color:LightGray\n" - + "}\n" + "td.gray2\n" + "{\n" - + "border:0px solid blue;text-align:right;background-color:gray\n" - + "}\n" + "td.gray3\n" + "{\n" - + "border:0px solid red;text-align:right;background-color:DarkGray\n" - + "}\n" + "th" + "{\n" + " text-align: center;\n" - + " vertical-align: bottom;\n" - + " padding-bottom: 3px;\n" + " padding-left: 5px;\n" - + " padding-right: 5px;\n" + "}\n" + " .verticalText\n" - + " {\n" + " text-align: center;\n" - + " vertical-align: middle;\n" + " width: 20px;\n" - + " margin: 0px;\n" + " padding: 0px;\n" - + " padding-left: 3px;\n" + " padding-right: 3px;\n" - + " padding-top: 10px;\n" + " white-space: nowrap;\n" - + " -webkit-transform: rotate(-90deg); \n" - + " -moz-transform: rotate(-90deg); \n" + " };\n" - + "</style>\n"; - private static final String FOOTER = "</html></body>"; - - // CSS style names. - private static final String CSS_TABLE = "table"; - private static final String CSS_LABEL = "label"; - private static final String CSS_TALL_HEADER = "tall"; - private static final String CSS_VERTICAL = "verticalText"; - private static final String CSS_CELL = "cell"; - private static final String CSS_EMPTY = "empty"; - private static final String[] CSS_GRAY_CELLS = {"white", "gray1", "gray2", "gray3", "black"}; - - private ConfusionMatrixDumper() {} - - public static void main(String[] args) throws Exception { - ToolRunner.run(new ConfusionMatrixDumper(), args); - } - - @Override - public int run(String[] args) throws IOException { - addInputOption(); - addOption("output", "o", "Output path", null); // AbstractJob output feature requires param - addOption(DefaultOptionCreator.overwriteOption().create()); - addFlag("html", null, "Create complete HTML page"); - addFlag("text", null, "Dump simple text"); - Map<String,List<String>> parsedArgs = parseArguments(args); - if (parsedArgs == null) { - return -1; - } - - Path inputPath = getInputPath(); - String outputFile = hasOption("output") ? 
getOption("output") : null; - boolean text = parsedArgs.containsKey("--text"); - boolean wrapHtml = parsedArgs.containsKey("--html"); - PrintStream out = getPrintStream(outputFile); - if (text) { - exportText(inputPath, out); - } else { - exportTable(inputPath, out, wrapHtml); - } - out.flush(); - if (out != System.out) { - out.close(); - } - return 0; - } - - private static void exportText(Path inputPath, PrintStream out) throws IOException { - MatrixWritable mw = new MatrixWritable(); - Text key = new Text(); - readSeqFile(inputPath, key, mw); - Matrix m = mw.get(); - ConfusionMatrix cm = new ConfusionMatrix(m); - out.println(String.format("%-40s", "Label") + TAB_SEPARATOR + String.format("%-10s", "Total") - + TAB_SEPARATOR + String.format("%-10s", "Correct") + TAB_SEPARATOR - + String.format("%-6s", "%") + TAB_SEPARATOR); - out.println(String.format("%-70s", "-").replace(' ', '-')); - List<String> labels = stripDefault(cm); - for (String label : labels) { - int correct = cm.getCorrect(label); - double accuracy = cm.getAccuracy(label); - int count = getCount(cm, label); - out.println(String.format("%-40s", label) + TAB_SEPARATOR + String.format("%-10s", count) - + TAB_SEPARATOR + String.format("%-10s", correct) + TAB_SEPARATOR - + String.format("%-6s", (int) Math.round(accuracy)) + TAB_SEPARATOR); - } - out.println(String.format("%-70s", "-").replace(' ', '-')); - out.println(cm.toString()); - } - - private static void exportTable(Path inputPath, PrintStream out, boolean wrapHtml) throws IOException { - MatrixWritable mw = new MatrixWritable(); - Text key = new Text(); - readSeqFile(inputPath, key, mw); - String fileName = inputPath.getName(); - fileName = fileName.substring(fileName.lastIndexOf('/') + 1, fileName.length()); - Matrix m = mw.get(); - ConfusionMatrix cm = new ConfusionMatrix(m); - if (wrapHtml) { - printHeader(out, fileName); - } - out.println("<p/>"); - printSummaryTable(cm, out); - out.println("<p/>"); - printGrayTable(cm, out); - out.println("<p/>"); - printCountsTable(cm, out); - out.println("<p/>"); - printTextInBox(cm, out); - out.println("<p/>"); - if (wrapHtml) { - printFooter(out); - } - } - - private static List<String> stripDefault(ConfusionMatrix cm) { - List<String> stripped = Lists.newArrayList(cm.getLabels().iterator()); - String defaultLabel = cm.getDefaultLabel(); - int unclassified = cm.getTotal(defaultLabel); - if (unclassified > 0) { - return stripped; - } - stripped.remove(defaultLabel); - return stripped; - } - - // TODO: test - this should work with HDFS files - private static void readSeqFile(Path path, Text key, MatrixWritable m) throws IOException { - Configuration conf = new Configuration(); - FileSystem fs = FileSystem.get(conf); - SequenceFile.Reader reader = new SequenceFile.Reader(fs, path, conf); - reader.next(key, m); - } - - // TODO: test - this might not work with HDFS files? 
- // after all, it does no seeks - private static PrintStream getPrintStream(String outputFilename) throws IOException { - if (outputFilename != null) { - File outputFile = new File(outputFilename); - if (outputFile.exists()) { - outputFile.delete(); - } - outputFile.createNewFile(); - OutputStream os = new FileOutputStream(outputFile); - return new PrintStream(os, false, Charsets.UTF_8.displayName()); - } else { - return System.out; - } - } - - private static int getLabelTotal(ConfusionMatrix cm, String rowLabel) { - Iterator<String> iter = cm.getLabels().iterator(); - int count = 0; - while (iter.hasNext()) { - count += cm.getCount(rowLabel, iter.next()); - } - return count; - } - - // HTML generator code - - private static void printTextInBox(ConfusionMatrix cm, PrintStream out) { - out.println("<div style='width:90%;overflow:scroll;'>"); - out.println("<pre>"); - out.println(cm.toString()); - out.println("</pre>"); - out.println("</div>"); - } - - public static void printSummaryTable(ConfusionMatrix cm, PrintStream out) { - format("<table class='%s'>\n", out, CSS_TABLE); - format("<tr class='%s'>", out, CSS_LABEL); - out.println("<td>Label</td><td>Total</td><td>Correct</td><td>%</td>"); - out.println("</tr>"); - List<String> labels = stripDefault(cm); - for (String label : labels) { - printSummaryRow(cm, out, label); - } - out.println("</table>"); - } - - private static void printSummaryRow(ConfusionMatrix cm, PrintStream out, String label) { - format("<tr class='%s'>", out, CSS_CELL); - int correct = cm.getCorrect(label); - double accuracy = cm.getAccuracy(label); - int count = getCount(cm, label); - format("<td class='%s'>%s</td><td>%d</td><td>%d</td><td>%d</td>", out, CSS_CELL, label, count, correct, - (int) Math.round(accuracy)); - out.println("</tr>"); - } - - private static int getCount(ConfusionMatrix cm, String label) { - int count = 0; - for (String s : cm.getLabels()) { - count += cm.getCount(label, s); - } - return count; - } - - public static void printGrayTable(ConfusionMatrix cm, PrintStream out) { - format("<table class='%s'>\n", out, CSS_TABLE); - printCountsHeader(cm, out, true); - printGrayRows(cm, out); - out.println("</table>"); - } - - /** - * Print each value in a four-value grayscale based on count/max. Gives a mostly white matrix with grays in - * misclassified, and black in diagonal. 
TODO: Using the sqrt(count/max) as the rating is more stringent - */ - private static void printGrayRows(ConfusionMatrix cm, PrintStream out) { - List<String> labels = stripDefault(cm); - for (String label : labels) { - printGrayRow(cm, out, labels, label); - } - } - - private static void printGrayRow(ConfusionMatrix cm, - PrintStream out, - Iterable<String> labels, - String rowLabel) { - format("<tr class='%s'>", out, CSS_LABEL); - format("<td>%s</td>", out, rowLabel); - int total = getLabelTotal(cm, rowLabel); - for (String columnLabel : labels) { - printGrayCell(cm, out, total, rowLabel, columnLabel); - } - out.println("</tr>"); - } - - // assign white/light/medium/dark to 0,1/4,1/2,3/4 of total number of inputs - // assign black to count = total, meaning complete success - // alternative rating is to use sqrt(total) instead of total - this is more drastic - private static void printGrayCell(ConfusionMatrix cm, - PrintStream out, - int total, - String rowLabel, - String columnLabel) { - - int count = cm.getCount(rowLabel, columnLabel); - if (count == 0) { - out.format("<td class='%s'/>", CSS_EMPTY); - } else { - // 0 is white, full is black, everything else gray - int rating = (int) ((count / (double) total) * 4); - String css = CSS_GRAY_CELLS[rating]; - format("<td class='%s' title='%s'>%s</td>", out, css, columnLabel, count); - } - } - - public static void printCountsTable(ConfusionMatrix cm, PrintStream out) { - format("<table class='%s'>\n", out, CSS_TABLE); - printCountsHeader(cm, out, false); - printCountsRows(cm, out); - out.println("</table>"); - } - - private static void printCountsRows(ConfusionMatrix cm, PrintStream out) { - List<String> labels = stripDefault(cm); - for (String label : labels) { - printCountsRow(cm, out, labels, label); - } - } - - private static void printCountsRow(ConfusionMatrix cm, - PrintStream out, - Iterable<String> labels, - String rowLabel) { - out.println("<tr>"); - format("<td class='%s'>%s</td>", out, CSS_LABEL, rowLabel); - for (String columnLabel : labels) { - printCountsCell(cm, out, rowLabel, columnLabel); - } - out.println("</tr>"); - } - - private static void printCountsCell(ConfusionMatrix cm, PrintStream out, String rowLabel, String columnLabel) { - int count = cm.getCount(rowLabel, columnLabel); - String s = count == 0 ? "" : Integer.toString(count); - format("<td class='%s' title='%s'>%s</td>", out, CSS_CELL, columnLabel, s); - } - - private static void printCountsHeader(ConfusionMatrix cm, PrintStream out, boolean vertical) { - List<String> labels = stripDefault(cm); - int longest = getLongestHeader(labels); - if (vertical) { - // do vertical - rotation is a bitch - out.format("<tr class='%s' style='height:%dem'><th> </th>%n", CSS_TALL_HEADER, longest / 2); - for (String label : labels) { - out.format("<th><div class='%s'>%s</div></th>", CSS_VERTICAL, label); - } - out.println("</tr>"); - } else { - // header - empty cell in upper left - out.format("<tr class='%s'><td class='%s'></td>%n", CSS_TABLE, CSS_LABEL); - for (String label : labels) { - out.format("<td>%s</td>", label); - } - out.format("</tr>"); - } - } - - private static int getLongestHeader(Iterable<String> labels) { - int max = 0; - for (String label : labels) { - max = Math.max(label.length(), max); - } - return max; - } - - private static void format(String format, PrintStream out, Object... 
args) { - String format2 = String.format(format, args); - out.println(format2); - } - - public static void printHeader(PrintStream out, CharSequence title) { - out.println(HEADER.replace("TITLE", title)); - } - - public static void printFooter(PrintStream out) { - out.println(FOOTER); - } - -}
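
For reference, the removed ConfusionMatrixDumper exposed its HTML generators as public static methods, so they could also be called directly against a ConfusionMatrix rather than through the command line. A minimal sketch, assuming a ConfusionMatrix `cm` already produced by a TestClassifier run (the wrapper class name here is hypothetical):

    import java.io.PrintStream;
    import org.apache.mahout.classifier.ConfusionMatrix;
    import org.apache.mahout.classifier.ConfusionMatrixDumper;

    public class DumperSketch {
      // Mirrors the view order of exportTable() above, minus the <pre> text box.
      static void render(ConfusionMatrix cm, PrintStream out) {
        ConfusionMatrixDumper.printHeader(out, "classifier results"); // HTML head + CSS
        ConfusionMatrixDumper.printSummaryTable(cm, out); // Label | Total | Correct | %
        ConfusionMatrixDumper.printGrayTable(cm, out);    // grayscale density matrix
        ConfusionMatrixDumper.printCountsTable(cm, out);  // raw per-cell counts
        ConfusionMatrixDumper.printFooter(out);           // closes the page
      }
    }
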
http://git-wip-us.apache.org/repos/asf/mahout/blob/99a5358f/integration/src/main/java/org/apache/mahout/clustering/cdbw/CDbwEvaluator.java ---------------------------------------------------------------------- diff --git a/integration/src/main/java/org/apache/mahout/clustering/cdbw/CDbwEvaluator.java b/integration/src/main/java/org/apache/mahout/clustering/cdbw/CDbwEvaluator.java deleted file mode 100644 index 545c1ff..0000000 --- a/integration/src/main/java/org/apache/mahout/clustering/cdbw/CDbwEvaluator.java +++ /dev/null @@ -1,387 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.mahout.clustering.cdbw; - -import org.apache.hadoop.conf.Configuration; -import org.apache.hadoop.fs.Path; -import org.apache.mahout.clustering.Cluster; -import org.apache.mahout.clustering.GaussianAccumulator; -import org.apache.mahout.clustering.OnlineGaussianAccumulator; -import org.apache.mahout.clustering.evaluation.RepresentativePointsDriver; -import org.apache.mahout.clustering.evaluation.RepresentativePointsMapper; -import org.apache.mahout.clustering.iterator.ClusterWritable; -import org.apache.mahout.common.ClassUtils; -import org.apache.mahout.common.distance.DistanceMeasure; -import org.apache.mahout.common.iterator.sequencefile.PathFilters; -import org.apache.mahout.common.iterator.sequencefile.PathType; -import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirValueIterable; -import org.apache.mahout.math.RandomAccessSparseVector; -import org.apache.mahout.math.Vector; -import org.apache.mahout.math.Vector.Element; -import org.apache.mahout.math.VectorWritable; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.TreeMap; - -/** - * This class calculates the CDbw metric as defined in - * http://www.db-net.aueb.gr/index.php/corporate/content/download/227/833/file/HV_poster2002.pdf - */ -public final class CDbwEvaluator { - - private static final Logger log = LoggerFactory.getLogger(CDbwEvaluator.class); - - private final Map<Integer,List<VectorWritable>> representativePoints; - private final Map<Integer,Double> stDevs = new HashMap<>(); - private final List<Cluster> clusters; - private final DistanceMeasure measure; - private Double interClusterDensity = null; - // these are symmetric so we only compute half of them - private Map<Integer,Map<Integer,Double>> minimumDistances = null; - // these are symmetric too - private Map<Integer,Map<Integer,Double>> interClusterDensities = null; - // these are symmetric too - private Map<Integer,Map<Integer,int[]>> closestRepPointIndices = null; - - /** - * For testing only - * - * @param representativePoints - * a 
Map<Integer,List<VectorWritable>> of representative points keyed by clusterId - * @param clusters - * a Map<Integer,Cluster> of the clusters keyed by clusterId - * @param measure - * an appropriate DistanceMeasure - */ - public CDbwEvaluator(Map<Integer,List<VectorWritable>> representativePoints, List<Cluster> clusters, - DistanceMeasure measure) { - this.representativePoints = representativePoints; - this.clusters = clusters; - this.measure = measure; - for (Integer cId : representativePoints.keySet()) { - computeStd(cId); - } - } - - /** - * Initialize a new instance from job information - * - * @param conf - * a Configuration with appropriate parameters - * @param clustersIn - * a String path to the input clusters directory - */ - public CDbwEvaluator(Configuration conf, Path clustersIn) { - measure = ClassUtils - .instantiateAs(conf.get(RepresentativePointsDriver.DISTANCE_MEASURE_KEY), DistanceMeasure.class); - representativePoints = RepresentativePointsMapper.getRepresentativePoints(conf); - clusters = loadClusters(conf, clustersIn); - for (Integer cId : representativePoints.keySet()) { - computeStd(cId); - } - } - - /** - * Load the clusters from their sequence files - * - * @param clustersIn - * a String pathname to the directory containing input cluster files - * @return a List<Cluster> of the clusters - */ - private static List<Cluster> loadClusters(Configuration conf, Path clustersIn) { - List<Cluster> clusters = new ArrayList<>(); - for (ClusterWritable clusterWritable : new SequenceFileDirValueIterable<ClusterWritable>(clustersIn, PathType.LIST, - PathFilters.logsCRCFilter(), conf)) { - Cluster cluster = clusterWritable.getValue(); - clusters.add(cluster); - } - return clusters; - } - - /** - * Compute the standard deviation of the representative points for the given cluster. Store these in stDevs, indexed - * by cI - * - * @param cI - * a int clusterId. - */ - private void computeStd(int cI) { - List<VectorWritable> repPts = representativePoints.get(cI); - GaussianAccumulator accumulator = new OnlineGaussianAccumulator(); - for (VectorWritable vw : repPts) { - accumulator.observe(vw.get(), 1.0); - } - accumulator.compute(); - double d = accumulator.getAverageStd(); - stDevs.put(cI, d); - } - - /** - * Compute the density of points near the midpoint between the two closest points of the clusters (eqn 2) used for - * inter-cluster density calculation - * - * @param uIJ - * the Vector midpoint between the closest representative points of the clusters - * @param cI - * the int clusterId of the i-th cluster - * @param cJ - * the int clusterId of the j-th cluster - * @param avgStd - * the double average standard deviation of the two clusters - * @return a double - */ - private double density(Vector uIJ, int cI, int cJ, double avgStd) { - List<VectorWritable> repI = representativePoints.get(cI); - List<VectorWritable> repJ = representativePoints.get(cJ); - double sum = 0.0; - // count the number of representative points of the clusters which are within the - // average std of the two clusters from the midpoint uIJ (eqn 3) - for (VectorWritable vwI : repI) { - if (uIJ != null && measure.distance(uIJ, vwI.get()) <= avgStd) { - sum++; - } - } - for (VectorWritable vwJ : repJ) { - if (uIJ != null && measure.distance(uIJ, vwJ.get()) <= avgStd) { - sum++; - } - } - int nI = repI.size(); - int nJ = repJ.size(); - return sum / (nI + nJ); - } - - /** - * Compute the CDbw validity metric (eqn 8). 
The goal of this metric is to reward clusterings which have a high - * intraClusterDensity and also a high cluster separation. - * - * @return a double - */ - public double getCDbw() { - return intraClusterDensity() * separation(); - } - - /** - * The average density within clusters is defined as the percentage of representative points that reside in the - * neighborhood of the clusters' centers. The goal is the density within clusters to be significantly high. (eqn 5) - * - * @return a double - */ - public double intraClusterDensity() { - double avgDensity = 0; - int count = 0; - for (Element elem : intraClusterDensities().nonZeroes()) { - double value = elem.get(); - if (!Double.isNaN(value)) { - avgDensity += value; - count++; - } - } - return avgDensity / count; - } - - /** - * This function evaluates the density of points in the regions between each clusters (eqn 1). The goal is the density - * in the area between clusters to be significant low. - * - * @return a Map<Integer,Map<Integer,Double>> of the inter-cluster densities - */ - public Map<Integer,Map<Integer,Double>> interClusterDensities() { - if (interClusterDensities != null) { - return interClusterDensities; - } - interClusterDensities = new TreeMap<>(); - // find the closest representative points between the clusters - for (int i = 0; i < clusters.size(); i++) { - int cI = clusters.get(i).getId(); - Map<Integer,Double> map = new TreeMap<>(); - interClusterDensities.put(cI, map); - for (int j = i + 1; j < clusters.size(); j++) { - int cJ = clusters.get(j).getId(); - double minDistance = minimumDistance(cI, cJ); // the distance between the closest representative points - Vector uIJ = midpointVector(cI, cJ); // the midpoint between the closest representative points - double stdSum = stDevs.get(cI) + stDevs.get(cJ); - double density = density(uIJ, cI, cJ, stdSum / 2); - double interDensity = minDistance * density / stdSum; - map.put(cJ, interDensity); - if (log.isDebugEnabled()) { - log.debug("minDistance[{},{}]={}", cI, cJ, minDistance); - log.debug("interDensity[{},{}]={}", cI, cJ, density); - log.debug("density[{},{}]={}", cI, cJ, interDensity); - } - } - } - return interClusterDensities; - } - - /** - * Calculate the separation of clusters (eqn 4) taking into account both the distances between the clusters' closest - * points and the Inter-cluster density. The goal is the distances between clusters to be high while the - * representative point density in the areas between them are low. - * - * @return a double - */ - public double separation() { - double minDistanceSum = 0; - Map<Integer,Map<Integer,Double>> distances = minimumDistances(); - for (Map<Integer,Double> map : distances.values()) { - for (Double dist : map.values()) { - if (!Double.isInfinite(dist)) { - minDistanceSum += dist * 2; // account for other half of calculated triangular minimumDistances matrix - } - } - } - return minDistanceSum / (1.0 + interClusterDensity()); - } - - /** - * This function evaluates the average density of points in the regions between clusters (eqn 1). The goal is the - * density in the area between clusters to be significant low. 
- * - * @return a double - */ - public double interClusterDensity() { - if (interClusterDensity != null) { - return interClusterDensity; - } - double sum = 0.0; - int count = 0; - Map<Integer,Map<Integer,Double>> distances = interClusterDensities(); - for (Map<Integer,Double> row : distances.values()) { - for (Double density : row.values()) { - if (!Double.isNaN(density)) { - sum += density; - count++; - } - } - } - log.debug("interClusterDensity={}", sum); - interClusterDensity = sum / count; - return interClusterDensity; - } - - /** - * The average density within clusters is defined as the percentage of representative points that reside in the - * neighborhood of the clusters' centers. The goal is the density within clusters to be significantly high. (eqn 5) - * - * @return a Vector of the intra-densities of each clusterId - */ - public Vector intraClusterDensities() { - Vector densities = new RandomAccessSparseVector(Integer.MAX_VALUE); - // compute the average standard deviation of the clusters - double stdev = 0.0; - for (Integer cI : representativePoints.keySet()) { - stdev += stDevs.get(cI); - } - int c = representativePoints.size(); - stdev /= c; - for (Cluster cluster : clusters) { - Integer cI = cluster.getId(); - List<VectorWritable> repPtsI = representativePoints.get(cI); - int r = repPtsI.size(); - double sumJ = 0.0; - // compute the term density (eqn 6) - for (VectorWritable pt : repPtsI) { - // compute f(x, vIJ) (eqn 7) - Vector repJ = pt.get(); - double densityIJ = measure.distance(cluster.getCenter(), repJ) <= stdev ? 1.0 : 0.0; - // accumulate sumJ - sumJ += densityIJ / stdev; - } - densities.set(cI, sumJ / r); - } - return densities; - } - - /** - * Calculate and cache the distances between the clusters' closest representative points. 
Also cache the indices of - * the closest representative points used for later use - * - * @return a Map<Integer,Vector> of the closest distances, keyed by clusterId - */ - private Map<Integer,Map<Integer,Double>> minimumDistances() { - if (minimumDistances != null) { - return minimumDistances; - } - minimumDistances = new TreeMap<>(); - closestRepPointIndices = new TreeMap<>(); - for (int i = 0; i < clusters.size(); i++) { - Integer cI = clusters.get(i).getId(); - Map<Integer,Double> map = new TreeMap<>(); - Map<Integer,int[]> treeMap = new TreeMap<>(); - closestRepPointIndices.put(cI, treeMap); - minimumDistances.put(cI, map); - List<VectorWritable> closRepI = representativePoints.get(cI); - for (int j = i + 1; j < clusters.size(); j++) { - // find min{d(closRepI, closRepJ)} - Integer cJ = clusters.get(j).getId(); - List<VectorWritable> closRepJ = representativePoints.get(cJ); - double minDistance = Double.MAX_VALUE; - int[] midPointIndices = null; - for (int xI = 0; xI < closRepI.size(); xI++) { - VectorWritable aRepI = closRepI.get(xI); - for (int xJ = 0; xJ < closRepJ.size(); xJ++) { - VectorWritable aRepJ = closRepJ.get(xJ); - double distance = measure.distance(aRepI.get(), aRepJ.get()); - if (distance < minDistance) { - minDistance = distance; - midPointIndices = new int[] {xI, xJ}; - } - } - } - map.put(cJ, minDistance); - treeMap.put(cJ, midPointIndices); - } - } - return minimumDistances; - } - - private double minimumDistance(int cI, int cJ) { - Map<Integer,Double> distances = minimumDistances().get(cI); - if (distances != null) { - return distances.get(cJ); - } else { - return minimumDistances().get(cJ).get(cI); - } - } - - private Vector midpointVector(int cI, int cJ) { - Map<Integer,Double> distances = minimumDistances().get(cI); - if (distances != null) { - int[] ks = closestRepPointIndices.get(cI).get(cJ); - if (ks == null) { - return null; - } - return representativePoints.get(cI).get(ks[0]).get().plus(representativePoints.get(cJ).get(ks[1]).get()) - .divide(2); - } else { - int[] ks = closestRepPointIndices.get(cJ).get(cI); - if (ks == null) { - return null; - } - return representativePoints.get(cJ).get(ks[1]).get().plus(representativePoints.get(cI).get(ks[0]).get()) - .divide(2); - } - - } -} http://git-wip-us.apache.org/repos/asf/mahout/blob/99a5358f/integration/src/main/java/org/apache/mahout/clustering/conversion/InputDriver.java ---------------------------------------------------------------------- diff --git a/integration/src/main/java/org/apache/mahout/clustering/conversion/InputDriver.java b/integration/src/main/java/org/apache/mahout/clustering/conversion/InputDriver.java deleted file mode 100644 index 6a2b376..0000000 --- a/integration/src/main/java/org/apache/mahout/clustering/conversion/InputDriver.java +++ /dev/null @@ -1,114 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.mahout.clustering.conversion; - -import java.io.IOException; - -import org.apache.commons.cli2.CommandLine; -import org.apache.commons.cli2.Group; -import org.apache.commons.cli2.Option; -import org.apache.commons.cli2.OptionException; -import org.apache.commons.cli2.builder.ArgumentBuilder; -import org.apache.commons.cli2.builder.DefaultOptionBuilder; -import org.apache.commons.cli2.builder.GroupBuilder; -import org.apache.commons.cli2.commandline.Parser; -import org.apache.hadoop.conf.Configuration; -import org.apache.hadoop.fs.Path; -import org.apache.hadoop.io.Text; -import org.apache.hadoop.mapreduce.Job; -import org.apache.hadoop.mapreduce.lib.input.FileInputFormat; -import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat; -import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat; -import org.apache.mahout.common.CommandLineUtil; -import org.apache.mahout.common.commandline.DefaultOptionCreator; -import org.apache.mahout.math.VectorWritable; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -/** - * This class converts text files containing space-delimited floating point numbers into - * Mahout sequence files of VectorWritable suitable for input to the clustering jobs in - * particular, and any Mahout job requiring this input in general. - * - */ -public final class InputDriver { - - private static final Logger log = LoggerFactory.getLogger(InputDriver.class); - - private InputDriver() { - } - - public static void main(String[] args) throws IOException, InterruptedException, ClassNotFoundException { - DefaultOptionBuilder obuilder = new DefaultOptionBuilder(); - ArgumentBuilder abuilder = new ArgumentBuilder(); - GroupBuilder gbuilder = new GroupBuilder(); - - Option inputOpt = DefaultOptionCreator.inputOption().withRequired(false).create(); - Option outputOpt = DefaultOptionCreator.outputOption().withRequired(false).create(); - Option vectorOpt = obuilder.withLongName("vector").withRequired(false).withArgument( - abuilder.withName("v").withMinimum(1).withMaximum(1).create()).withDescription( - "The vector implementation to use.").withShortName("v").create(); - - Option helpOpt = DefaultOptionCreator.helpOption(); - - Group group = gbuilder.withName("Options").withOption(inputOpt).withOption(outputOpt).withOption( - vectorOpt).withOption(helpOpt).create(); - - try { - Parser parser = new Parser(); - parser.setGroup(group); - CommandLine cmdLine = parser.parse(args); - if (cmdLine.hasOption(helpOpt)) { - CommandLineUtil.printHelp(group); - return; - } - - Path input = new Path(cmdLine.getValue(inputOpt, "testdata").toString()); - Path output = new Path(cmdLine.getValue(outputOpt, "output").toString()); - String vectorClassName = cmdLine.getValue(vectorOpt, - "org.apache.mahout.math.RandomAccessSparseVector").toString(); - runJob(input, output, vectorClassName); - } catch (OptionException e) { - log.error("Exception parsing command line: ", e); - CommandLineUtil.printHelp(group); - } - } - - public static void runJob(Path input, Path output, String vectorClassName) - throws IOException, InterruptedException, ClassNotFoundException { - Configuration conf = new Configuration(); - conf.set("vector.implementation.class.name", vectorClassName); - Job job = new Job(conf, "Input Driver running over input: " + input); - - job.setOutputKeyClass(Text.class); - job.setOutputValueClass(VectorWritable.class); - 
job.setOutputFormatClass(SequenceFileOutputFormat.class); - job.setMapperClass(InputMapper.class); - job.setNumReduceTasks(0); - job.setJarByClass(InputDriver.class); - - FileInputFormat.addInputPath(job, input); - FileOutputFormat.setOutputPath(job, output); - - boolean succeeded = job.waitForCompletion(true); - if (!succeeded) { - throw new IllegalStateException("Job failed!"); - } - } - -} http://git-wip-us.apache.org/repos/asf/mahout/blob/99a5358f/integration/src/main/java/org/apache/mahout/clustering/conversion/InputMapper.java ---------------------------------------------------------------------- diff --git a/integration/src/main/java/org/apache/mahout/clustering/conversion/InputMapper.java b/integration/src/main/java/org/apache/mahout/clustering/conversion/InputMapper.java deleted file mode 100644 index e4c72c6..0000000 --- a/integration/src/main/java/org/apache/mahout/clustering/conversion/InputMapper.java +++ /dev/null @@ -1,81 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.mahout.clustering.conversion; - -import org.apache.hadoop.conf.Configuration; -import org.apache.hadoop.io.LongWritable; -import org.apache.hadoop.io.Text; -import org.apache.hadoop.mapreduce.Mapper; -import org.apache.mahout.math.Vector; -import org.apache.mahout.math.VectorWritable; - -import java.io.IOException; -import java.lang.reflect.Constructor; -import java.lang.reflect.InvocationTargetException; -import java.util.ArrayList; -import java.util.Collection; -import java.util.regex.Pattern; - -public class InputMapper extends Mapper<LongWritable, Text, Text, VectorWritable> { - - private static final Pattern SPACE = Pattern.compile(" "); - - private Constructor<?> constructor; - - @Override - protected void map(LongWritable key, Text values, Context context) throws IOException, InterruptedException { - - String[] numbers = SPACE.split(values.toString()); - // sometimes there are multiple separator spaces - Collection<Double> doubles = new ArrayList<>(); - for (String value : numbers) { - if (!value.isEmpty()) { - doubles.add(Double.valueOf(value)); - } - } - // ignore empty lines in data file - if (!doubles.isEmpty()) { - try { - Vector result = (Vector) constructor.newInstance(doubles.size()); - int index = 0; - for (Double d : doubles) { - result.set(index++, d); - } - VectorWritable vectorWritable = new VectorWritable(result); - context.write(new Text(String.valueOf(index)), vectorWritable); - - } catch (InstantiationException | IllegalAccessException | InvocationTargetException e) { - throw new IllegalStateException(e); - } - } - } - - @Override - protected void setup(Context context) throws IOException, InterruptedException { - super.setup(context); - Configuration conf = context.getConfiguration(); - String 
vectorImplClassName = conf.get("vector.implementation.class.name"); - try { - Class<? extends Vector> outputClass = conf.getClassByName(vectorImplClassName).asSubclass(Vector.class); - constructor = outputClass.getConstructor(int.class); - } catch (NoSuchMethodException | ClassNotFoundException e) { - throw new IllegalStateException(e); - } - } - -} http://git-wip-us.apache.org/repos/asf/mahout/blob/99a5358f/integration/src/main/java/org/apache/mahout/clustering/evaluation/ClusterEvaluator.java ---------------------------------------------------------------------- diff --git a/integration/src/main/java/org/apache/mahout/clustering/evaluation/ClusterEvaluator.java b/integration/src/main/java/org/apache/mahout/clustering/evaluation/ClusterEvaluator.java deleted file mode 100644 index 757f38c..0000000 --- a/integration/src/main/java/org/apache/mahout/clustering/evaluation/ClusterEvaluator.java +++ /dev/null @@ -1,196 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.mahout.clustering.evaluation; - -import org.apache.hadoop.conf.Configuration; -import org.apache.hadoop.fs.Path; -import org.apache.mahout.clustering.Cluster; -import org.apache.mahout.clustering.iterator.ClusterWritable; -import org.apache.mahout.common.ClassUtils; -import org.apache.mahout.common.distance.DistanceMeasure; -import org.apache.mahout.common.iterator.sequencefile.PathFilters; -import org.apache.mahout.common.iterator.sequencefile.PathType; -import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirValueIterable; -import org.apache.mahout.math.RandomAccessSparseVector; -import org.apache.mahout.math.Vector; -import org.apache.mahout.math.Vector.Element; -import org.apache.mahout.math.VectorWritable; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import java.util.ArrayList; -import java.util.List; -import java.util.Map; -import java.util.TreeMap; - -public class ClusterEvaluator { - - private static final Logger log = LoggerFactory.getLogger(ClusterEvaluator.class); - - private final Map<Integer,List<VectorWritable>> representativePoints; - - private final List<Cluster> clusters; - - private final DistanceMeasure measure; - - /** - * For testing only - * - * @param representativePoints - * a Map<Integer,List<VectorWritable>> of representative points keyed by clusterId - * @param clusters - * a Map<Integer,Cluster> of the clusters keyed by clusterId - * @param measure - * an appropriate DistanceMeasure - */ - public ClusterEvaluator(Map<Integer,List<VectorWritable>> representativePoints, List<Cluster> clusters, - DistanceMeasure measure) { - this.representativePoints = representativePoints; - this.clusters = clusters; - this.measure = measure; - } - - /** - * Initialize a new instance from job information - * - * 
@param conf - * a Configuration with appropriate parameters - * @param clustersIn - * a String path to the input clusters directory - */ - public ClusterEvaluator(Configuration conf, Path clustersIn) { - measure = ClassUtils - .instantiateAs(conf.get(RepresentativePointsDriver.DISTANCE_MEASURE_KEY), DistanceMeasure.class); - representativePoints = RepresentativePointsMapper.getRepresentativePoints(conf); - clusters = loadClusters(conf, clustersIn); - } - - /** - * Load the clusters from their sequence files - * - * @param clustersIn - * a String pathname to the directory containing input cluster files - * @return a List<Cluster> of the clusters - */ - private static List<Cluster> loadClusters(Configuration conf, Path clustersIn) { - List<Cluster> clusters = new ArrayList<>(); - for (ClusterWritable clusterWritable : new SequenceFileDirValueIterable<ClusterWritable>(clustersIn, PathType.LIST, - PathFilters.logsCRCFilter(), conf)) { - Cluster cluster = clusterWritable.getValue(); - clusters.add(cluster); - } - return clusters; - } - - /** - * Computes the inter-cluster density as defined in "Mahout In Action" - * - * @return the interClusterDensity - */ - public double interClusterDensity() { - double max = Double.NEGATIVE_INFINITY; - double min = Double.POSITIVE_INFINITY; - double sum = 0; - int count = 0; - Map<Integer,Vector> distances = interClusterDistances(); - for (Vector row : distances.values()) { - for (Element element : row.nonZeroes()) { - double d = element.get(); - min = Math.min(d, min); - max = Math.max(d, max); - sum += d; - count++; - } - } - double density = (sum / count - min) / (max - min); - log.info("Scaled Inter-Cluster Density = {}", density); - return density; - } - - /** - * Computes the inter-cluster distances - * - * @return a Map<Integer, Vector> - */ - public Map<Integer,Vector> interClusterDistances() { - Map<Integer,Vector> distances = new TreeMap<>(); - for (int i = 0; i < clusters.size(); i++) { - Cluster clusterI = clusters.get(i); - RandomAccessSparseVector row = new RandomAccessSparseVector(Integer.MAX_VALUE); - distances.put(clusterI.getId(), row); - for (int j = i + 1; j < clusters.size(); j++) { - Cluster clusterJ = clusters.get(j); - double d = measure.distance(clusterI.getCenter(), clusterJ.getCenter()); - row.set(clusterJ.getId(), d); - } - } - return distances; - } - - /** - * Computes the average intra-cluster density as the average of each cluster's intra-cluster density - * - * @return the average intraClusterDensity - */ - public double intraClusterDensity() { - double avgDensity = 0; - int count = 0; - for (Element elem : intraClusterDensities().nonZeroes()) { - double value = elem.get(); - if (!Double.isNaN(value)) { - avgDensity += value; - count++; - } - } - avgDensity = clusters.isEmpty() ? 
0 : avgDensity / count; - log.info("Average Intra-Cluster Density = {}", avgDensity); - return avgDensity; - } - - /** - * Computes the intra-cluster densities for all clusters as the average distance of the representative points from - * each other - * - * @return a Vector of the intraClusterDensity of the representativePoints by clusterId - */ - public Vector intraClusterDensities() { - Vector densities = new RandomAccessSparseVector(Integer.MAX_VALUE); - for (Cluster cluster : clusters) { - int count = 0; - double max = Double.NEGATIVE_INFINITY; - double min = Double.POSITIVE_INFINITY; - double sum = 0; - List<VectorWritable> repPoints = representativePoints.get(cluster.getId()); - for (int i = 0; i < repPoints.size(); i++) { - for (int j = i + 1; j < repPoints.size(); j++) { - Vector v1 = repPoints.get(i).get(); - Vector v2 = repPoints.get(j).get(); - double d = measure.distance(v1, v2); - min = Math.min(d, min); - max = Math.max(d, max); - sum += d; - count++; - } - } - double density = (sum / count - min) / (max - min); - densities.set(cluster.getId(), density); - log.info("Intra-Cluster Density[{}] = {}", cluster.getId(), density); - } - return densities; - } -} http://git-wip-us.apache.org/repos/asf/mahout/blob/99a5358f/integration/src/main/java/org/apache/mahout/clustering/evaluation/RepresentativePointsDriver.java ---------------------------------------------------------------------- diff --git a/integration/src/main/java/org/apache/mahout/clustering/evaluation/RepresentativePointsDriver.java b/integration/src/main/java/org/apache/mahout/clustering/evaluation/RepresentativePointsDriver.java deleted file mode 100644 index 2fe37ef..0000000 --- a/integration/src/main/java/org/apache/mahout/clustering/evaluation/RepresentativePointsDriver.java +++ /dev/null @@ -1,243 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.mahout.clustering.evaluation; - -import java.io.IOException; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.Map.Entry; -import org.apache.hadoop.conf.Configuration; -import org.apache.hadoop.fs.FileStatus; -import org.apache.hadoop.fs.FileSystem; -import org.apache.hadoop.fs.Path; -import org.apache.hadoop.io.IntWritable; -import org.apache.hadoop.io.SequenceFile; -import org.apache.hadoop.mapreduce.Job; -import org.apache.hadoop.mapreduce.lib.input.FileInputFormat; -import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat; -import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat; -import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat; -import org.apache.hadoop.util.ToolRunner; -import org.apache.mahout.clustering.AbstractCluster; -import org.apache.mahout.clustering.Cluster; -import org.apache.mahout.clustering.classify.WeightedVectorWritable; -import org.apache.mahout.clustering.iterator.ClusterWritable; -import org.apache.mahout.common.AbstractJob; -import org.apache.mahout.common.ClassUtils; -import org.apache.mahout.common.Pair; -import org.apache.mahout.common.commandline.DefaultOptionCreator; -import org.apache.mahout.common.distance.DistanceMeasure; -import org.apache.mahout.common.iterator.sequencefile.PathFilters; -import org.apache.mahout.common.iterator.sequencefile.PathType; -import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirIterable; -import org.apache.mahout.common.iterator.sequencefile.SequenceFileValueIterable; -import org.apache.mahout.math.VectorWritable; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -public final class RepresentativePointsDriver extends AbstractJob { - - public static final String STATE_IN_KEY = "org.apache.mahout.clustering.stateIn"; - - public static final String DISTANCE_MEASURE_KEY = "org.apache.mahout.clustering.measure"; - - private static final Logger log = LoggerFactory.getLogger(RepresentativePointsDriver.class); - - private RepresentativePointsDriver() {} - - public static void main(String[] args) throws Exception { - ToolRunner.run(new Configuration(), new RepresentativePointsDriver(), args); - } - - @Override - public int run(String[] args) throws ClassNotFoundException, IOException, InterruptedException { - addInputOption(); - addOutputOption(); - addOption("clusteredPoints", "cp", "The path to the clustered points", true); - addOption(DefaultOptionCreator.distanceMeasureOption().create()); - addOption(DefaultOptionCreator.maxIterationsOption().create()); - addOption(DefaultOptionCreator.methodOption().create()); - if (parseArguments(args) == null) { - return -1; - } - - Path input = getInputPath(); - Path output = getOutputPath(); - String distanceMeasureClass = getOption(DefaultOptionCreator.DISTANCE_MEASURE_OPTION); - int maxIterations = Integer.parseInt(getOption(DefaultOptionCreator.MAX_ITERATIONS_OPTION)); - boolean runSequential = getOption(DefaultOptionCreator.METHOD_OPTION).equalsIgnoreCase( - DefaultOptionCreator.SEQUENTIAL_METHOD); - DistanceMeasure measure = ClassUtils.instantiateAs(distanceMeasureClass, DistanceMeasure.class); - Path clusteredPoints = new Path(getOption("clusteredPoints")); - run(getConf(), input, clusteredPoints, output, measure, maxIterations, runSequential); - return 0; - } - - /** - * Utility to print out representative points - * - * @param output - * the Path to the directory containing representativePoints-i folders - * @param numIterations - * the int number of 
iterations to print - */ - public static void printRepresentativePoints(Path output, int numIterations) { - for (int i = 0; i <= numIterations; i++) { - Path out = new Path(output, "representativePoints-" + i); - System.out.println("Representative Points for iteration " + i); - Configuration conf = new Configuration(); - for (Pair<IntWritable,VectorWritable> record : new SequenceFileDirIterable<IntWritable,VectorWritable>(out, - PathType.LIST, PathFilters.logsCRCFilter(), null, true, conf)) { - System.out.println("\tC-" + record.getFirst().get() + ": " - + AbstractCluster.formatVector(record.getSecond().get(), null)); - } - } - } - - public static void run(Configuration conf, Path clustersIn, Path clusteredPointsIn, Path output, - DistanceMeasure measure, int numIterations, boolean runSequential) throws IOException, InterruptedException, - ClassNotFoundException { - Path stateIn = new Path(output, "representativePoints-0"); - writeInitialState(stateIn, clustersIn); - - for (int iteration = 0; iteration < numIterations; iteration++) { - log.info("Representative Points Iteration {}", iteration); - // point the output to a new directory per iteration - Path stateOut = new Path(output, "representativePoints-" + (iteration + 1)); - runIteration(conf, clusteredPointsIn, stateIn, stateOut, measure, runSequential); - // now point the input to the old output directory - stateIn = stateOut; - } - - conf.set(STATE_IN_KEY, stateIn.toString()); - conf.set(DISTANCE_MEASURE_KEY, measure.getClass().getName()); - } - - private static void writeInitialState(Path output, Path clustersIn) throws IOException { - Configuration conf = new Configuration(); - FileSystem fs = FileSystem.get(output.toUri(), conf); - for (FileStatus dir : fs.globStatus(clustersIn)) { - Path inPath = dir.getPath(); - for (FileStatus part : fs.listStatus(inPath, PathFilters.logsCRCFilter())) { - Path inPart = part.getPath(); - Path path = new Path(output, inPart.getName()); - try (SequenceFile.Writer writer = - new SequenceFile.Writer(fs, conf, path, IntWritable.class, VectorWritable.class)){ - for (ClusterWritable clusterWritable : new SequenceFileValueIterable<ClusterWritable>(inPart, true, conf)) { - Cluster cluster = clusterWritable.getValue(); - if (log.isDebugEnabled()) { - log.debug("C-{}: {}", cluster.getId(), AbstractCluster.formatVector(cluster.getCenter(), null)); - } - writer.append(new IntWritable(cluster.getId()), new VectorWritable(cluster.getCenter())); - } - } - } - } - } - - private static void runIteration(Configuration conf, Path clusteredPointsIn, Path stateIn, Path stateOut, - DistanceMeasure measure, boolean runSequential) throws IOException, InterruptedException, ClassNotFoundException { - if (runSequential) { - runIterationSeq(conf, clusteredPointsIn, stateIn, stateOut, measure); - } else { - runIterationMR(conf, clusteredPointsIn, stateIn, stateOut, measure); - } - } - - /** - * Run the job using supplied arguments as a sequential process - * - * @param conf - * the Configuration to use - * @param clusteredPointsIn - * the directory pathname for input points - * @param stateIn - * the directory pathname for input state - * @param stateOut - * the directory pathname for output state - * @param measure - * the DistanceMeasure to use - */ - private static void runIterationSeq(Configuration conf, Path clusteredPointsIn, Path stateIn, Path stateOut, - DistanceMeasure measure) throws IOException { - - Map<Integer,List<VectorWritable>> repPoints = RepresentativePointsMapper.getRepresentativePoints(conf, stateIn); - 
Map<Integer,WeightedVectorWritable> mostDistantPoints = new HashMap<>(); - FileSystem fs = FileSystem.get(clusteredPointsIn.toUri(), conf); - for (Pair<IntWritable,WeightedVectorWritable> record - : new SequenceFileDirIterable<IntWritable,WeightedVectorWritable>(clusteredPointsIn, PathType.LIST, - PathFilters.logsCRCFilter(), null, true, conf)) { - RepresentativePointsMapper.mapPoint(record.getFirst(), record.getSecond(), measure, repPoints, mostDistantPoints); - } - int part = 0; - try (SequenceFile.Writer writer = new SequenceFile.Writer(fs, conf, new Path(stateOut, "part-m-" + part++), - IntWritable.class, VectorWritable.class)){ - for (Entry<Integer,List<VectorWritable>> entry : repPoints.entrySet()) { - for (VectorWritable vw : entry.getValue()) { - writer.append(new IntWritable(entry.getKey()), vw); - } - } - } - try (SequenceFile.Writer writer = new SequenceFile.Writer(fs, conf, new Path(stateOut, "part-m-" + part++), - IntWritable.class, VectorWritable.class)){ - for (Map.Entry<Integer,WeightedVectorWritable> entry : mostDistantPoints.entrySet()) { - writer.append(new IntWritable(entry.getKey()), new VectorWritable(entry.getValue().getVector())); - } - } - } - - /** - * Run the job using supplied arguments as a Map/Reduce process - * - * @param conf - * the Configuration to use - * @param input - * the directory pathname for input points - * @param stateIn - * the directory pathname for input state - * @param stateOut - * the directory pathname for output state - * @param measure - * the DistanceMeasure to use - */ - private static void runIterationMR(Configuration conf, Path input, Path stateIn, Path stateOut, - DistanceMeasure measure) throws IOException, InterruptedException, ClassNotFoundException { - conf.set(STATE_IN_KEY, stateIn.toString()); - conf.set(DISTANCE_MEASURE_KEY, measure.getClass().getName()); - Job job = new Job(conf, "Representative Points Driver running over input: " + input); - job.setJarByClass(RepresentativePointsDriver.class); - job.setOutputKeyClass(IntWritable.class); - job.setOutputValueClass(VectorWritable.class); - job.setMapOutputKeyClass(IntWritable.class); - job.setMapOutputValueClass(WeightedVectorWritable.class); - - FileInputFormat.setInputPaths(job, input); - FileOutputFormat.setOutputPath(job, stateOut); - - job.setMapperClass(RepresentativePointsMapper.class); - job.setReducerClass(RepresentativePointsReducer.class); - job.setInputFormatClass(SequenceFileInputFormat.class); - job.setOutputFormatClass(SequenceFileOutputFormat.class); - - boolean succeeded = job.waitForCompletion(true); - if (!succeeded) { - throw new IllegalStateException("Job failed!"); - } - } -} http://git-wip-us.apache.org/repos/asf/mahout/blob/99a5358f/integration/src/main/java/org/apache/mahout/clustering/evaluation/RepresentativePointsMapper.java ---------------------------------------------------------------------- diff --git a/integration/src/main/java/org/apache/mahout/clustering/evaluation/RepresentativePointsMapper.java b/integration/src/main/java/org/apache/mahout/clustering/evaluation/RepresentativePointsMapper.java deleted file mode 100644 index 0ae79ad..0000000 --- a/integration/src/main/java/org/apache/mahout/clustering/evaluation/RepresentativePointsMapper.java +++ /dev/null @@ -1,117 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. 
- * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.mahout.clustering.evaluation; - -import java.io.IOException; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Map; - -import org.apache.hadoop.conf.Configuration; -import org.apache.hadoop.fs.Path; -import org.apache.hadoop.io.IntWritable; -import org.apache.hadoop.mapreduce.Mapper; -import org.apache.mahout.clustering.classify.WeightedVectorWritable; -import org.apache.mahout.common.ClassUtils; -import org.apache.mahout.common.Pair; -import org.apache.mahout.common.distance.DistanceMeasure; -import org.apache.mahout.common.distance.EuclideanDistanceMeasure; -import org.apache.mahout.common.iterator.sequencefile.PathFilters; -import org.apache.mahout.common.iterator.sequencefile.PathType; -import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirIterable; -import org.apache.mahout.math.VectorWritable; - -public class RepresentativePointsMapper - extends Mapper<IntWritable, WeightedVectorWritable, IntWritable, WeightedVectorWritable> { - - private Map<Integer, List<VectorWritable>> representativePoints; - private final Map<Integer, WeightedVectorWritable> mostDistantPoints = new HashMap<>(); - private DistanceMeasure measure = new EuclideanDistanceMeasure(); - - @Override - protected void cleanup(Context context) throws IOException, InterruptedException { - for (Map.Entry<Integer, WeightedVectorWritable> entry : mostDistantPoints.entrySet()) { - context.write(new IntWritable(entry.getKey()), entry.getValue()); - } - super.cleanup(context); - } - - @Override - protected void map(IntWritable clusterId, WeightedVectorWritable point, Context context) - throws IOException, InterruptedException { - mapPoint(clusterId, point, measure, representativePoints, mostDistantPoints); - } - - public static void mapPoint(IntWritable clusterId, - WeightedVectorWritable point, - DistanceMeasure measure, - Map<Integer, List<VectorWritable>> representativePoints, - Map<Integer, WeightedVectorWritable> mostDistantPoints) { - int key = clusterId.get(); - WeightedVectorWritable currentMDP = mostDistantPoints.get(key); - - List<VectorWritable> repPoints = representativePoints.get(key); - double totalDistance = 0.0; - if (repPoints != null) { - for (VectorWritable refPoint : repPoints) { - totalDistance += measure.distance(refPoint.get(), point.getVector()); - } - } - if (currentMDP == null || currentMDP.getWeight() < totalDistance) { - mostDistantPoints.put(key, new WeightedVectorWritable(totalDistance, point.getVector().clone())); - } - } - - @Override - protected void setup(Context context) throws IOException, InterruptedException { - super.setup(context); - Configuration conf = context.getConfiguration(); - measure = - ClassUtils.instantiateAs(conf.get(RepresentativePointsDriver.DISTANCE_MEASURE_KEY), DistanceMeasure.class); - representativePoints = getRepresentativePoints(conf); - } - - public void configure(Map<Integer, List<VectorWritable>> 
referencePoints, DistanceMeasure measure) { - this.representativePoints = referencePoints; - this.measure = measure; - } - - public static Map<Integer, List<VectorWritable>> getRepresentativePoints(Configuration conf) { - String statePath = conf.get(RepresentativePointsDriver.STATE_IN_KEY); - return getRepresentativePoints(conf, new Path(statePath)); - } - - public static Map<Integer, List<VectorWritable>> getRepresentativePoints(Configuration conf, Path statePath) { - Map<Integer, List<VectorWritable>> representativePoints = new HashMap<>(); - for (Pair<IntWritable,VectorWritable> record - : new SequenceFileDirIterable<IntWritable,VectorWritable>(statePath, - PathType.LIST, - PathFilters.logsCRCFilter(), - conf)) { - int keyValue = record.getFirst().get(); - List<VectorWritable> repPoints = representativePoints.get(keyValue); - if (repPoints == null) { - repPoints = new ArrayList<>(); - representativePoints.put(keyValue, repPoints); - } - repPoints.add(record.getSecond()); - } - return representativePoints; - } -} http://git-wip-us.apache.org/repos/asf/mahout/blob/99a5358f/integration/src/main/java/org/apache/mahout/clustering/evaluation/RepresentativePointsReducer.java ---------------------------------------------------------------------- diff --git a/integration/src/main/java/org/apache/mahout/clustering/evaluation/RepresentativePointsReducer.java b/integration/src/main/java/org/apache/mahout/clustering/evaluation/RepresentativePointsReducer.java deleted file mode 100644 index 27ca861..0000000 --- a/integration/src/main/java/org/apache/mahout/clustering/evaluation/RepresentativePointsReducer.java +++ /dev/null @@ -1,70 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
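The selection rule in mapPoint is worth restating: a candidate point's weight is its summed distance to all of the cluster's current representative points, and the heaviest candidate seen becomes the cluster's next representative. A dependency-free sketch of that scoring, using plain double[] vectors and the mapper's default Euclidean measure (all names hypothetical):

import java.util.Arrays;
import java.util.List;

public class MostDistantSketch {

  // Sum of Euclidean distances from 'point' to every current representative.
  static double score(double[] point, List<double[]> reps) {
    double total = 0.0;
    for (double[] r : reps) {
      double sq = 0.0;
      for (int d = 0; d < point.length; d++) {
        double diff = r[d] - point[d];
        sq += diff * diff;
      }
      total += Math.sqrt(sq);
    }
    return total;
  }

  public static void main(String[] args) {
    List<double[]> reps = Arrays.asList(new double[] {0, 0}, new double[] {1, 0});
    double[] near = {0.5, 0.1};
    double[] far = {4.0, 3.0};
    // The farther point scores higher and would replace the cluster's
    // current most-distant candidate, exactly as in mapPoint.
    System.out.println(score(near, reps) < score(far, reps)); // true
  }
}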
- */ - -package org.apache.mahout.clustering.evaluation; - -import java.io.IOException; -import java.util.List; -import java.util.Map; - -import org.apache.hadoop.conf.Configuration; -import org.apache.hadoop.io.IntWritable; -import org.apache.hadoop.mapreduce.Reducer; -import org.apache.mahout.clustering.classify.WeightedVectorWritable; -import org.apache.mahout.math.VectorWritable; - -public class RepresentativePointsReducer - extends Reducer<IntWritable, WeightedVectorWritable, IntWritable, VectorWritable> { - - private Map<Integer, List<VectorWritable>> representativePoints; - - @Override - protected void cleanup(Context context) throws IOException, InterruptedException { - for (Map.Entry<Integer, List<VectorWritable>> entry : representativePoints.entrySet()) { - IntWritable iw = new IntWritable(entry.getKey()); - for (VectorWritable vw : entry.getValue()) { - context.write(iw, vw); - } - } - super.cleanup(context); - } - - @Override - protected void reduce(IntWritable key, Iterable<WeightedVectorWritable> values, Context context) - throws IOException, InterruptedException { - // find the most distant point - WeightedVectorWritable mdp = null; - for (WeightedVectorWritable dpw : values) { - if (mdp == null || mdp.getWeight() < dpw.getWeight()) { - mdp = new WeightedVectorWritable(dpw.getWeight(), dpw.getVector()); - } - } - context.write(new IntWritable(key.get()), new VectorWritable(mdp.getVector())); - } - - @Override - protected void setup(Context context) throws IOException, InterruptedException { - super.setup(context); - Configuration conf = context.getConfiguration(); - representativePoints = RepresentativePointsMapper.getRepresentativePoints(conf); - } - - public void configure(Map<Integer, List<VectorWritable>> representativePoints) { - this.representativePoints = representativePoints; - } - -} http://git-wip-us.apache.org/repos/asf/mahout/blob/99a5358f/integration/src/main/java/org/apache/mahout/clustering/lda/LDAPrintTopics.java ---------------------------------------------------------------------- diff --git a/integration/src/main/java/org/apache/mahout/clustering/lda/LDAPrintTopics.java b/integration/src/main/java/org/apache/mahout/clustering/lda/LDAPrintTopics.java deleted file mode 100644 index 392909e..0000000 --- a/integration/src/main/java/org/apache/mahout/clustering/lda/LDAPrintTopics.java +++ /dev/null @@ -1,229 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
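The reduce step above is a plain arg-max over the candidates the mappers emitted for one cluster. A standalone sketch with a small stand-in class for WeightedVectorWritable (hypothetical names; Hadoop never hands reduce() an empty Iterable, which is why the original can dereference mdp unguarded):

import java.util.Arrays;
import java.util.List;

public class MaxWeightSketch {

  // Stand-in for WeightedVectorWritable: a score plus an opaque payload.
  static final class Weighted {
    final double weight;
    final String vector;
    Weighted(double weight, String vector) { this.weight = weight; this.vector = vector; }
  }

  // Mirrors reduce(): keep the single heaviest (most distant) candidate.
  static Weighted mostDistant(Iterable<Weighted> values) {
    Weighted mdp = null;
    for (Weighted w : values) {
      if (mdp == null || mdp.weight < w.weight) {
        mdp = w;
      }
    }
    return mdp;
  }

  public static void main(String[] args) {
    List<Weighted> values = Arrays.asList(
        new Weighted(1.5, "a"), new Weighted(9.2, "b"), new Weighted(4.0, "c"));
    System.out.println(mostDistant(values).vector); // b
  }
}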
- */ - -package org.apache.mahout.clustering.lda; - -import com.google.common.io.Closeables; -import java.io.File; -import java.io.FileOutputStream; -import java.io.IOException; -import java.io.OutputStreamWriter; -import java.io.Writer; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collection; -import java.util.Collections; -import java.util.Comparator; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.PriorityQueue; -import java.util.Queue; -import org.apache.commons.cli2.CommandLine; -import org.apache.commons.cli2.Group; -import org.apache.commons.cli2.Option; -import org.apache.commons.cli2.OptionException; -import org.apache.commons.cli2.builder.ArgumentBuilder; -import org.apache.commons.cli2.builder.DefaultOptionBuilder; -import org.apache.commons.cli2.builder.GroupBuilder; -import org.apache.commons.cli2.commandline.Parser; -import org.apache.commons.io.Charsets; -import org.apache.hadoop.conf.Configuration; -import org.apache.hadoop.fs.Path; -import org.apache.hadoop.io.DoubleWritable; -import org.apache.mahout.common.CommandLineUtil; -import org.apache.mahout.common.IntPairWritable; -import org.apache.mahout.common.Pair; -import org.apache.mahout.common.commandline.DefaultOptionCreator; -import org.apache.mahout.common.iterator.sequencefile.PathType; -import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirIterable; -import org.apache.mahout.utils.vectors.VectorHelper; - -/** - * Class to print out the top K words for each topic. - */ -public final class LDAPrintTopics { - - private LDAPrintTopics() { } - - // Expands the queue list to have a Queue for topic K - private static void ensureQueueSize(Collection<Queue<Pair<String,Double>>> queues, int k) { - for (int i = queues.size(); i <= k; ++i) { - queues.add(new PriorityQueue<Pair<String,Double>>()); - } - } - - public static void main(String[] args) throws Exception { - DefaultOptionBuilder obuilder = new DefaultOptionBuilder(); - ArgumentBuilder abuilder = new ArgumentBuilder(); - GroupBuilder gbuilder = new GroupBuilder(); - - Option inputOpt = DefaultOptionCreator.inputOption().create(); - - Option dictOpt = obuilder.withLongName("dict").withRequired(true).withArgument( - abuilder.withName("dict").withMinimum(1).withMaximum(1).create()).withDescription( - "Dictionary to read in, in the same format as one created by " - + "org.apache.mahout.utils.vectors.lucene.Driver").withShortName("d").create(); - - Option outOpt = DefaultOptionCreator.outputOption().create(); - - Option wordOpt = obuilder.withLongName("words").withRequired(false).withArgument( - abuilder.withName("words").withMinimum(0).withMaximum(1).withDefault("20").create()).withDescription( - "Number of words to print").withShortName("w").create(); - Option dictTypeOpt = obuilder.withLongName("dictionaryType").withRequired(false).withArgument( - abuilder.withName("dictionaryType").withMinimum(1).withMaximum(1).create()).withDescription( - "The dictionary file type (text|sequencefile)").withShortName("dt").create(); - Option helpOpt = obuilder.withLongName("help").withDescription("Print out help").withShortName("h") - .create(); - - // helpOpt must be part of the group, otherwise --help is rejected as an unknown option - Group group = gbuilder.withName("Options").withOption(dictOpt).withOption(outOpt).withOption(wordOpt) - .withOption(inputOpt).withOption(dictTypeOpt).withOption(helpOpt).create(); - try { - Parser parser = new Parser(); - parser.setGroup(group); - CommandLine cmdLine = parser.parse(args); - - if (cmdLine.hasOption(helpOpt)) { - CommandLineUtil.printHelp(group); - return; - } -
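Given the options wired up above, a driver invocation looks roughly like the following; every path is made up for illustration, and --words and --dictionaryType fall back to "20" and "text" when omitted:

public class LDAPrintTopicsExample {
  public static void main(String[] args) throws Exception {
    org.apache.mahout.clustering.lda.LDAPrintTopics.main(new String[] {
        "--input", "lda/state-20",    // hypothetical LDA state directory (part-* files)
        "--dict", "dictionary.txt",   // term dictionary, text format
        "--words", "15",              // top K words per topic
        "--output", "topics"          // writes one file per topic; stdout if omitted
    });
  }
}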
- String input = cmdLine.getValue(inputOpt).toString(); - String dictFile = cmdLine.getValue(dictOpt).toString(); - int numWords = 20; - if (cmdLine.hasOption(wordOpt)) { - numWords = Integer.parseInt(cmdLine.getValue(wordOpt).toString()); - } - Configuration config = new Configuration(); - - String dictionaryType = "text"; - if (cmdLine.hasOption(dictTypeOpt)) { - dictionaryType = cmdLine.getValue(dictTypeOpt).toString(); - } - - List<String> wordList; - if ("text".equals(dictionaryType)) { - wordList = Arrays.asList(VectorHelper.loadTermDictionary(new File(dictFile))); - } else if ("sequencefile".equals(dictionaryType)) { - wordList = Arrays.asList(VectorHelper.loadTermDictionary(config, dictFile)); - } else { - throw new IllegalArgumentException("Invalid dictionary format"); - } - - List<Queue<Pair<String,Double>>> topWords = topWordsForTopics(input, config, wordList, numWords); - - File output = null; - if (cmdLine.hasOption(outOpt)) { - output = new File(cmdLine.getValue(outOpt).toString()); - if (!output.exists() && !output.mkdirs()) { - throw new IOException("Could not create directory: " + output); - } - } - printTopWords(topWords, output); - } catch (OptionException e) { - CommandLineUtil.printHelp(group); - throw e; - } - } - - // Adds the word if the queue is below capacity, or the score is high enough - private static void maybeEnqueue(Queue<Pair<String,Double>> q, String word, double score, int numWordsToPrint) { - if (q.size() >= numWordsToPrint && score > q.peek().getSecond()) { - q.poll(); - } - if (q.size() < numWordsToPrint) { - q.add(new Pair<>(word, score)); - } - } - - private static void printTopWords(List<Queue<Pair<String,Double>>> topWords, File outputDir) - throws IOException { - for (int i = 0; i < topWords.size(); ++i) { - Collection<Pair<String,Double>> topK = topWords.get(i); - Writer out = null; - boolean printingToSystemOut = false; - try { - if (outputDir != null) { - out = new OutputStreamWriter(new FileOutputStream(new File(outputDir, "topic_" + i)), Charsets.UTF_8); - } else { - out = new OutputStreamWriter(System.out, Charsets.UTF_8); - printingToSystemOut = true; - out.write("Topic " + i); - out.write('\n'); - out.write("==========="); - out.write('\n'); - } - List<Pair<String,Double>> topKasList = new ArrayList<>(topK.size()); - for (Pair<String,Double> wordWithScore : topK) { - topKasList.add(wordWithScore); - } - Collections.sort(topKasList, new Comparator<Pair<String,Double>>() { - @Override - public int compare(Pair<String,Double> pair1, Pair<String,Double> pair2) { - return pair2.getSecond().compareTo(pair1.getSecond()); - } - }); - for (Pair<String,Double> wordWithScore : topKasList) { - // close the bracket opened by " [p(" so each line reads "word [p(word|topic_i) = score]" - out.write(wordWithScore.getFirst() + " [p(" + wordWithScore.getFirst() + "|topic_" + i + ") = " - + wordWithScore.getSecond() + "]"); - out.write('\n'); - } - } finally { - if (!printingToSystemOut) { - Closeables.close(out, false); - } else { - out.flush(); - } - } - } - } - - private static List<Queue<Pair<String,Double>>> topWordsForTopics(String dir, - Configuration job, - List<String> wordList, - int numWordsToPrint) { - List<Queue<Pair<String,Double>>> queues = new ArrayList<>(); - Map<Integer,Double> expSums = new HashMap<>(); - for (Pair<IntPairWritable,DoubleWritable> record - : new SequenceFileDirIterable<IntPairWritable, DoubleWritable>( - new Path(dir, "part-*"), PathType.GLOB, null, null, true, job)) { - IntPairWritable key = record.getFirst(); - int topic = key.getFirst(); - int word = key.getSecond(); - ensureQueueSize(queues, topic); - if (word >= 0 &&
topic >= 0) { - double score = record.getSecond().get(); - if (expSums.get(topic) == null) { - expSums.put(topic, 0.0); - } - expSums.put(topic, expSums.get(topic) + Math.exp(score)); - String realWord = wordList.get(word); - maybeEnqueue(queues.get(topic), realWord, score, numWordsToPrint); - } - } - for (int i = 0; i < queues.size(); i++) { - Queue<Pair<String,Double>> queue = queues.get(i); - Queue<Pair<String,Double>> newQueue = new PriorityQueue<>(queue.size()); - double norm = expSums.get(i); - for (Pair<String,Double> pair : queue) { - newQueue.add(new Pair<>(pair.getFirst(), Math.exp(pair.getSecond()) / norm)); - } - queues.set(i, newQueue); - } - return queues; - } -}
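Note the two-phase handling of scores in topWordsForTopics: the queues rank words by raw, log-scale score, and only at the end is each retained score converted to a probability, exp(score) divided by the topic's sum of exponentials. A self-contained sketch of that renormalization with made-up scores:

public class TopicNormalizationSketch {
  public static void main(String[] args) {
    double[] logScores = {-1.2, -0.7, -2.3};   // hypothetical per-word LDA scores
    double norm = 0.0;
    for (double s : logScores) {
      norm += Math.exp(s);                     // the expSums accumulation
    }
    double total = 0.0;
    for (double s : logScores) {
      double p = Math.exp(s) / norm;           // p(word | topic)
      total += p;
      System.out.printf("p = %.4f%n", p);
    }
    System.out.printf("sum = %.4f%n", total);  // 1.0000 by construction
  }
}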
