http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr-examples/src/main/java/org/apache/mahout/classifier/sgd/TrainLogistic.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr-examples/src/main/java/org/apache/mahout/classifier/sgd/TrainLogistic.java b/community/mahout-mr/mr-examples/src/main/java/org/apache/mahout/classifier/sgd/TrainLogistic.java new file mode 100644 index 0000000..f4b8bcb --- /dev/null +++ b/community/mahout-mr/mr-examples/src/main/java/org/apache/mahout/classifier/sgd/TrainLogistic.java @@ -0,0 +1,311 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.classifier.sgd; + +import com.google.common.io.Resources; +import org.apache.commons.cli2.CommandLine; +import org.apache.commons.cli2.Group; +import org.apache.commons.cli2.Option; +import org.apache.commons.cli2.builder.ArgumentBuilder; +import org.apache.commons.cli2.builder.DefaultOptionBuilder; +import org.apache.commons.cli2.builder.GroupBuilder; +import org.apache.commons.cli2.commandline.Parser; +import org.apache.commons.cli2.util.HelpFormatter; +import org.apache.commons.io.Charsets; +import org.apache.mahout.math.RandomAccessSparseVector; +import org.apache.mahout.math.Vector; + +import java.io.BufferedReader; +import java.io.File; +import java.io.FileInputStream; +import java.io.FileOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.InputStreamReader; +import java.io.OutputStream; +import java.io.OutputStreamWriter; +import java.io.PrintWriter; +import java.util.ArrayList; +import java.util.List; +import java.util.Locale; + +/** + * Train a logistic regression for the examples from Chapter 13 of Mahout in Action + */ +public final class TrainLogistic { + + private static String inputFile; + private static String outputFile; + private static LogisticModelParameters lmp; + private static int passes; + private static boolean scores; + private static OnlineLogisticRegression model; + + private TrainLogistic() { + } + + public static void main(String[] args) throws Exception { + mainToOutput(args, new PrintWriter(new OutputStreamWriter(System.out, Charsets.UTF_8), true)); + } + + static void mainToOutput(String[] args, PrintWriter output) throws Exception { + if (parseArgs(args)) { + double logPEstimate = 0; + int samples = 0; + + CsvRecordFactory csv = lmp.getCsvRecordFactory(); + OnlineLogisticRegression lr = lmp.createRegression(); + for (int pass = 0; pass < passes; pass++) { + try (BufferedReader in = open(inputFile)) { + // read variable names + csv.firstLine(in.readLine()); + + 
String line = in.readLine(); + while (line != null) { + // for each new line, get target and predictors + Vector input = new RandomAccessSparseVector(lmp.getNumFeatures()); + int targetValue = csv.processLine(line, input); + + // check performance while this is still news + double logP = lr.logLikelihood(targetValue, input); + if (!Double.isInfinite(logP)) { + if (samples < 20) { + logPEstimate = (samples * logPEstimate + logP) / (samples + 1); + } else { + logPEstimate = 0.95 * logPEstimate + 0.05 * logP; + } + samples++; + } + double p = lr.classifyScalar(input); + if (scores) { + output.printf(Locale.ENGLISH, "%10d %2d %10.2f %2.4f %10.4f %10.4f%n", + samples, targetValue, lr.currentLearningRate(), p, logP, logPEstimate); + } + + // now update model + lr.train(targetValue, input); + + line = in.readLine(); + } + } + } + + try (OutputStream modelOutput = new FileOutputStream(outputFile)) { + lmp.saveTo(modelOutput); + } + + output.println(lmp.getNumFeatures()); + output.println(lmp.getTargetVariable() + " ~ "); + String sep = ""; + for (String v : csv.getTraceDictionary().keySet()) { + double weight = predictorWeight(lr, 0, csv, v); + if (weight != 0) { + output.printf(Locale.ENGLISH, "%s%.3f*%s", sep, weight, v); + sep = " + "; + } + } + output.printf("%n"); + model = lr; + for (int row = 0; row < lr.getBeta().numRows(); row++) { + for (String key : csv.getTraceDictionary().keySet()) { + double weight = predictorWeight(lr, row, csv, key); + if (weight != 0) { + output.printf(Locale.ENGLISH, "%20s %.5f%n", key, weight); + } + } + for (int column = 0; column < lr.getBeta().numCols(); column++) { + output.printf(Locale.ENGLISH, "%15.9f ", lr.getBeta().get(row, column)); + } + output.println(); + } + } + } + + private static double predictorWeight(OnlineLogisticRegression lr, int row, RecordFactory csv, String predictor) { + double weight = 0; + for (Integer column : csv.getTraceDictionary().get(predictor)) { + weight += lr.getBeta().get(row, column); + } + return 
weight; + } + + private static boolean parseArgs(String[] args) { + DefaultOptionBuilder builder = new DefaultOptionBuilder(); + + Option help = builder.withLongName("help").withDescription("print this list").create(); + + Option quiet = builder.withLongName("quiet").withDescription("be extra quiet").create(); + Option scores = builder.withLongName("scores").withDescription("output score diagnostics during training").create(); + + ArgumentBuilder argumentBuilder = new ArgumentBuilder(); + Option inputFile = builder.withLongName("input") + .withRequired(true) + .withArgument(argumentBuilder.withName("input").withMaximum(1).create()) + .withDescription("where to get training data") + .create(); + + Option outputFile = builder.withLongName("output") + .withRequired(true) + .withArgument(argumentBuilder.withName("output").withMaximum(1).create()) + .withDescription("where to get training data") + .create(); + + Option predictors = builder.withLongName("predictors") + .withRequired(true) + .withArgument(argumentBuilder.withName("p").create()) + .withDescription("a list of predictor variables") + .create(); + + Option types = builder.withLongName("types") + .withRequired(true) + .withArgument(argumentBuilder.withName("t").create()) + .withDescription("a list of predictor variable types (numeric, word, or text)") + .create(); + + Option target = builder.withLongName("target") + .withRequired(true) + .withArgument(argumentBuilder.withName("target").withMaximum(1).create()) + .withDescription("the name of the target variable") + .create(); + + Option features = builder.withLongName("features") + .withArgument( + argumentBuilder.withName("numFeatures") + .withDefault("1000") + .withMaximum(1).create()) + .withDescription("the number of internal hashed features to use") + .create(); + + Option passes = builder.withLongName("passes") + .withArgument( + argumentBuilder.withName("passes") + .withDefault("2") + .withMaximum(1).create()) + .withDescription("the number of times to 
pass over the input data") + .create(); + + Option lambda = builder.withLongName("lambda") + .withArgument(argumentBuilder.withName("lambda").withDefault("1e-4").withMaximum(1).create()) + .withDescription("the amount of coefficient decay to use") + .create(); + + Option rate = builder.withLongName("rate") + .withArgument(argumentBuilder.withName("learningRate").withDefault("1e-3").withMaximum(1).create()) + .withDescription("the learning rate") + .create(); + + Option noBias = builder.withLongName("noBias") + .withDescription("don't include a bias term") + .create(); + + Option targetCategories = builder.withLongName("categories") + .withRequired(true) + .withArgument(argumentBuilder.withName("number").withMaximum(1).create()) + .withDescription("the number of target categories to be considered") + .create(); + + Group normalArgs = new GroupBuilder() + .withOption(help) + .withOption(quiet) + .withOption(inputFile) + .withOption(outputFile) + .withOption(target) + .withOption(targetCategories) + .withOption(predictors) + .withOption(types) + .withOption(passes) + .withOption(lambda) + .withOption(rate) + .withOption(noBias) + .withOption(features) + .create(); + + Parser parser = new Parser(); + parser.setHelpOption(help); + parser.setHelpTrigger("--help"); + parser.setGroup(normalArgs); + parser.setHelpFormatter(new HelpFormatter(" ", "", " ", 130)); + CommandLine cmdLine = parser.parseAndHelp(args); + + if (cmdLine == null) { + return false; + } + + TrainLogistic.inputFile = getStringArgument(cmdLine, inputFile); + TrainLogistic.outputFile = getStringArgument(cmdLine, outputFile); + + List<String> typeList = new ArrayList<>(); + for (Object x : cmdLine.getValues(types)) { + typeList.add(x.toString()); + } + + List<String> predictorList = new ArrayList<>(); + for (Object x : cmdLine.getValues(predictors)) { + predictorList.add(x.toString()); + } + + lmp = new LogisticModelParameters(); + lmp.setTargetVariable(getStringArgument(cmdLine, target)); + 
lmp.setMaxTargetCategories(getIntegerArgument(cmdLine, targetCategories)); + lmp.setNumFeatures(getIntegerArgument(cmdLine, features)); + lmp.setUseBias(!getBooleanArgument(cmdLine, noBias)); + lmp.setTypeMap(predictorList, typeList); + + lmp.setLambda(getDoubleArgument(cmdLine, lambda)); + lmp.setLearningRate(getDoubleArgument(cmdLine, rate)); + + TrainLogistic.scores = getBooleanArgument(cmdLine, scores); + TrainLogistic.passes = getIntegerArgument(cmdLine, passes); + + return true; + } + + private static String getStringArgument(CommandLine cmdLine, Option inputFile) { + return (String) cmdLine.getValue(inputFile); + } + + private static boolean getBooleanArgument(CommandLine cmdLine, Option option) { + return cmdLine.hasOption(option); + } + + private static int getIntegerArgument(CommandLine cmdLine, Option features) { + return Integer.parseInt((String) cmdLine.getValue(features)); + } + + private static double getDoubleArgument(CommandLine cmdLine, Option op) { + return Double.parseDouble((String) cmdLine.getValue(op)); + } + + public static OnlineLogisticRegression getModel() { + return model; + } + + public static LogisticModelParameters getParameters() { + return lmp; + } + + static BufferedReader open(String inputFile) throws IOException { + InputStream in; + try { + in = Resources.getResource(inputFile).openStream(); + } catch (IllegalArgumentException e) { + in = new FileInputStream(new File(inputFile)); + } + return new BufferedReader(new InputStreamReader(in, Charsets.UTF_8)); + } +}
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr-examples/src/main/java/org/apache/mahout/classifier/sgd/TrainNewsGroups.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr-examples/src/main/java/org/apache/mahout/classifier/sgd/TrainNewsGroups.java b/community/mahout-mr/mr-examples/src/main/java/org/apache/mahout/classifier/sgd/TrainNewsGroups.java new file mode 100644 index 0000000..632b32c --- /dev/null +++ b/community/mahout-mr/mr-examples/src/main/java/org/apache/mahout/classifier/sgd/TrainNewsGroups.java @@ -0,0 +1,154 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.classifier.sgd; + +import com.google.common.collect.HashMultiset; +import com.google.common.collect.Multiset; +import com.google.common.collect.Ordering; +import org.apache.mahout.classifier.NewsgroupHelper; +import org.apache.mahout.ep.State; +import org.apache.mahout.math.Vector; +import org.apache.mahout.vectorizer.encoders.Dictionary; + +import java.io.File; +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; + +/** + * Reads and trains an adaptive logistic regression model on the 20 newsgroups data. + * The first command line argument gives the path of the directory holding the training + * data. The optional second argument, leakType, defines which classes of features to use. + * Importantly, leakType controls whether a synthetic date is injected into the data as + * a target leak and if so, how. + * <p/> + * The value of leakType % 3 determines whether the target leak is injected according to + * the following table: + * <p/> + * <table> + * <tr><td valign='top'>0</td><td>No leak injected</td></tr> + * <tr><td valign='top'>1</td><td>Synthetic date injected in MMM-yyyy format. This will be a single token and + * is a perfect target leak since each newsgroup is given a different month</td></tr> + * <tr><td valign='top'>2</td><td>Synthetic date injected in dd-MMM-yyyy HH:mm:ss format. The day varies + * and thus there are more leak symbols that need to be learned. Ultimately this is just + * as big a leak as case 1.</td></tr> + * </table> + * <p/> + * Leaktype also determines what other text will be indexed. If leakType is greater + * than or equal to 6, then neither headers nor text body will be used for features and the leak is the only + * source of data. If leakType is greater than or equal to 3, then subject words will be used as features. + * If leakType is less than 3, then both subject and body text will be used as features. 
+ * <p/> + * A leakType of 0 gives no leak and all textual features. + * <p/> + * See the following table for a summary of commonly used values for leakType + * <p/> + * <table> + * <tr><td><b>leakType</b></td><td><b>Leak?</b></td><td><b>Subject?</b></td><td><b>Body?</b></td></tr> + * <tr><td colspan=4><hr></td></tr> + * <tr><td>0</td><td>no</td><td>yes</td><td>yes</td></tr> + * <tr><td>1</td><td>mmm-yyyy</td><td>yes</td><td>yes</td></tr> + * <tr><td>2</td><td>dd-mmm-yyyy</td><td>yes</td><td>yes</td></tr> + * <tr><td colspan=4><hr></td></tr> + * <tr><td>3</td><td>no</td><td>yes</td><td>no</td></tr> + * <tr><td>4</td><td>mmm-yyyy</td><td>yes</td><td>no</td></tr> + * <tr><td>5</td><td>dd-mmm-yyyy</td><td>yes</td><td>no</td></tr> + * <tr><td colspan=4><hr></td></tr> + * <tr><td>6</td><td>no</td><td>no</td><td>no</td></tr> + * <tr><td>7</td><td>mmm-yyyy</td><td>no</td><td>no</td></tr> + * <tr><td>8</td><td>dd-mmm-yyyy</td><td>no</td><td>no</td></tr> + * <tr><td colspan=4><hr></td></tr> + * </table> + */ +public final class TrainNewsGroups { + + private TrainNewsGroups() { + } + + public static void main(String[] args) throws IOException { + File base = new File(args[0]); + + Multiset<String> overallCounts = HashMultiset.create(); + + int leakType = 0; + if (args.length > 1) { + leakType = Integer.parseInt(args[1]); + } + + Dictionary newsGroups = new Dictionary(); + + NewsgroupHelper helper = new NewsgroupHelper(); + helper.getEncoder().setProbes(2); + AdaptiveLogisticRegression learningAlgorithm = + new AdaptiveLogisticRegression(20, NewsgroupHelper.FEATURES, new L1()); + learningAlgorithm.setInterval(800); + learningAlgorithm.setAveragingWindow(500); + + List<File> files = new ArrayList<>(); + for (File newsgroup : base.listFiles()) { + if (newsgroup.isDirectory()) { + newsGroups.intern(newsgroup.getName()); + files.addAll(Arrays.asList(newsgroup.listFiles())); + } + } + Collections.shuffle(files); + System.out.println(files.size() + " training files"); + SGDInfo 
info = new SGDInfo(); + + int k = 0; + + for (File file : files) { + String ng = file.getParentFile().getName(); + int actual = newsGroups.intern(ng); + + Vector v = helper.encodeFeatureVector(file, actual, leakType, overallCounts); + learningAlgorithm.train(actual, v); + + k++; + State<AdaptiveLogisticRegression.Wrapper, CrossFoldLearner> best = learningAlgorithm.getBest(); + + SGDHelper.analyzeState(info, leakType, k, best); + } + learningAlgorithm.close(); + SGDHelper.dissect(leakType, newsGroups, learningAlgorithm, files, overallCounts); + System.out.println("exiting main"); + + File modelFile = new File(System.getProperty("java.io.tmpdir"), "news-group.model"); + ModelSerializer.writeBinary(modelFile.getAbsolutePath(), + learningAlgorithm.getBest().getPayload().getLearner().getModels().get(0)); + + List<Integer> counts = new ArrayList<>(); + System.out.println("Word counts"); + for (String count : overallCounts.elementSet()) { + counts.add(overallCounts.count(count)); + } + Collections.sort(counts, Ordering.natural().reverse()); + k = 0; + for (Integer count : counts) { + System.out.println(k + "\t" + count); + k++; + if (k > 1000) { + break; + } + } + } + + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr-examples/src/main/java/org/apache/mahout/classifier/sgd/ValidateAdaptiveLogistic.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr-examples/src/main/java/org/apache/mahout/classifier/sgd/ValidateAdaptiveLogistic.java b/community/mahout-mr/mr-examples/src/main/java/org/apache/mahout/classifier/sgd/ValidateAdaptiveLogistic.java new file mode 100644 index 0000000..7a74289 --- /dev/null +++ b/community/mahout-mr/mr-examples/src/main/java/org/apache/mahout/classifier/sgd/ValidateAdaptiveLogistic.java @@ -0,0 +1,218 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.classifier.sgd; + +import java.io.BufferedReader; +import java.io.File; +import java.io.IOException; +import java.io.OutputStreamWriter; +import java.io.PrintWriter; +import java.util.Locale; + +import org.apache.commons.cli2.CommandLine; +import org.apache.commons.cli2.Group; +import org.apache.commons.cli2.Option; +import org.apache.commons.cli2.builder.ArgumentBuilder; +import org.apache.commons.cli2.builder.DefaultOptionBuilder; +import org.apache.commons.cli2.builder.GroupBuilder; +import org.apache.commons.cli2.commandline.Parser; +import org.apache.commons.cli2.util.HelpFormatter; +import org.apache.commons.io.Charsets; +import org.apache.mahout.classifier.ConfusionMatrix; +import org.apache.mahout.classifier.evaluation.Auc; +import org.apache.mahout.classifier.sgd.AdaptiveLogisticRegression.Wrapper; +import org.apache.mahout.ep.State; +import org.apache.mahout.math.Matrix; +import org.apache.mahout.math.SequentialAccessSparseVector; +import org.apache.mahout.math.Vector; +import org.apache.mahout.math.stats.OnlineSummarizer; + +/* + * Auc and averageLikelihood are always shown if possible, if the number of target value is more than 2, + * then Auc and entropy matirx are not shown regardless the value of showAuc and showEntropy + * the user 
passes, because the current implementation does not support them on two value targets. + * */ +public final class ValidateAdaptiveLogistic { + + private static String inputFile; + private static String modelFile; + private static String defaultCategory; + private static boolean showAuc; + private static boolean showScores; + private static boolean showConfusion; + + private ValidateAdaptiveLogistic() { + } + + public static void main(String[] args) throws IOException { + mainToOutput(args, new PrintWriter(new OutputStreamWriter(System.out, Charsets.UTF_8), true)); + } + + static void mainToOutput(String[] args, PrintWriter output) throws IOException { + if (parseArgs(args)) { + if (!showAuc && !showConfusion && !showScores) { + showAuc = true; + showConfusion = true; + } + + Auc collector = null; + AdaptiveLogisticModelParameters lmp = AdaptiveLogisticModelParameters + .loadFromFile(new File(modelFile)); + CsvRecordFactory csv = lmp.getCsvRecordFactory(); + AdaptiveLogisticRegression lr = lmp.createAdaptiveLogisticRegression(); + + if (lmp.getTargetCategories().size() <= 2) { + collector = new Auc(); + } + + OnlineSummarizer slh = new OnlineSummarizer(); + ConfusionMatrix cm = new ConfusionMatrix(lmp.getTargetCategories(), defaultCategory); + + State<Wrapper, CrossFoldLearner> best = lr.getBest(); + if (best == null) { + output.println("AdaptiveLogisticRegression has not be trained probably."); + return; + } + CrossFoldLearner learner = best.getPayload().getLearner(); + + BufferedReader in = TrainLogistic.open(inputFile); + String line = in.readLine(); + csv.firstLine(line); + line = in.readLine(); + if (showScores) { + output.println("\"target\", \"model-output\", \"log-likelihood\", \"average-likelihood\""); + } + while (line != null) { + Vector v = new SequentialAccessSparseVector(lmp.getNumFeatures()); + //TODO: How to avoid extra target values not shown in the training process. 
+ int target = csv.processLine(line, v); + double likelihood = learner.logLikelihood(target, v); + double score = learner.classifyFull(v).maxValue(); + + slh.add(likelihood); + cm.addInstance(csv.getTargetString(line), csv.getTargetLabel(target)); + + if (showScores) { + output.printf(Locale.ENGLISH, "%8d, %.12f, %.13f, %.13f%n", target, + score, learner.logLikelihood(target, v), slh.getMean()); + } + if (collector != null) { + collector.add(target, score); + } + line = in.readLine(); + } + + output.printf(Locale.ENGLISH,"\nLog-likelihood:"); + output.printf(Locale.ENGLISH, "Min=%.2f, Max=%.2f, Mean=%.2f, Median=%.2f%n", + slh.getMin(), slh.getMax(), slh.getMean(), slh.getMedian()); + + if (collector != null) { + output.printf(Locale.ENGLISH, "%nAUC = %.2f%n", collector.auc()); + } + + if (showConfusion) { + output.printf(Locale.ENGLISH, "%n%s%n%n", cm.toString()); + + if (collector != null) { + Matrix m = collector.entropy(); + output.printf(Locale.ENGLISH, + "Entropy Matrix: [[%.1f, %.1f], [%.1f, %.1f]]%n", m.get(0, 0), + m.get(1, 0), m.get(0, 1), m.get(1, 1)); + } + } + + } + } + + private static boolean parseArgs(String[] args) { + DefaultOptionBuilder builder = new DefaultOptionBuilder(); + + Option help = builder.withLongName("help") + .withDescription("print this list").create(); + + Option quiet = builder.withLongName("quiet") + .withDescription("be extra quiet").create(); + + Option auc = builder.withLongName("auc").withDescription("print AUC") + .create(); + Option confusion = builder.withLongName("confusion") + .withDescription("print confusion matrix").create(); + + Option scores = builder.withLongName("scores") + .withDescription("print scores").create(); + + ArgumentBuilder argumentBuilder = new ArgumentBuilder(); + Option inputFileOption = builder + .withLongName("input") + .withRequired(true) + .withArgument( + argumentBuilder.withName("input").withMaximum(1) + .create()) + .withDescription("where to get validate data").create(); + + Option 
modelFileOption = builder + .withLongName("model") + .withRequired(true) + .withArgument( + argumentBuilder.withName("model").withMaximum(1) + .create()) + .withDescription("where to get the trained model").create(); + + Option defaultCagetoryOption = builder + .withLongName("defaultCategory") + .withRequired(false) + .withArgument( + argumentBuilder.withName("defaultCategory").withMaximum(1).withDefault("unknown") + .create()) + .withDescription("the default category value to use").create(); + + Group normalArgs = new GroupBuilder().withOption(help) + .withOption(quiet).withOption(auc).withOption(scores) + .withOption(confusion).withOption(inputFileOption) + .withOption(modelFileOption).withOption(defaultCagetoryOption).create(); + + Parser parser = new Parser(); + parser.setHelpOption(help); + parser.setHelpTrigger("--help"); + parser.setGroup(normalArgs); + parser.setHelpFormatter(new HelpFormatter(" ", "", " ", 130)); + CommandLine cmdLine = parser.parseAndHelp(args); + + if (cmdLine == null) { + return false; + } + + inputFile = getStringArgument(cmdLine, inputFileOption); + modelFile = getStringArgument(cmdLine, modelFileOption); + defaultCategory = getStringArgument(cmdLine, defaultCagetoryOption); + showAuc = getBooleanArgument(cmdLine, auc); + showScores = getBooleanArgument(cmdLine, scores); + showConfusion = getBooleanArgument(cmdLine, confusion); + + return true; + } + + private static boolean getBooleanArgument(CommandLine cmdLine, Option option) { + return cmdLine.hasOption(option); + } + + private static String getStringArgument(CommandLine cmdLine, Option inputFile) { + return (String) cmdLine.getValue(inputFile); + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr-examples/src/main/java/org/apache/mahout/classifier/sgd/bankmarketing/BankMarketingClassificationMain.java ---------------------------------------------------------------------- diff --git 
a/community/mahout-mr/mr-examples/src/main/java/org/apache/mahout/classifier/sgd/bankmarketing/BankMarketingClassificationMain.java b/community/mahout-mr/mr-examples/src/main/java/org/apache/mahout/classifier/sgd/bankmarketing/BankMarketingClassificationMain.java new file mode 100644 index 0000000..ab3c861 --- /dev/null +++ b/community/mahout-mr/mr-examples/src/main/java/org/apache/mahout/classifier/sgd/bankmarketing/BankMarketingClassificationMain.java @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.classifier.sgd.bankmarketing; + +import com.google.common.collect.Lists; +import org.apache.mahout.classifier.evaluation.Auc; +import org.apache.mahout.classifier.sgd.L1; +import org.apache.mahout.classifier.sgd.OnlineLogisticRegression; + +import java.util.Collections; +import java.util.List; + +/** + * Uses the SGD classifier on the 'Bank marketing' dataset from UCI. + * + * See http://archive.ics.uci.edu/ml/datasets/Bank+Marketing + * + * Learn when people accept or reject an offer from the bank via telephone based on income, age, education and more. 
+ */ +public class BankMarketingClassificationMain { + + public static final int NUM_CATEGORIES = 2; + + public static void main(String[] args) throws Exception { + List<TelephoneCall> calls = Lists.newArrayList(new TelephoneCallParser("bank-full.csv")); + + double heldOutPercentage = 0.10; + + for (int run = 0; run < 20; run++) { + Collections.shuffle(calls); + int cutoff = (int) (heldOutPercentage * calls.size()); + List<TelephoneCall> test = calls.subList(0, cutoff); + List<TelephoneCall> train = calls.subList(cutoff, calls.size()); + + OnlineLogisticRegression lr = new OnlineLogisticRegression(NUM_CATEGORIES, TelephoneCall.FEATURES, new L1()) + .learningRate(1) + .alpha(1) + .lambda(0.000001) + .stepOffset(10000) + .decayExponent(0.2); + for (int pass = 0; pass < 20; pass++) { + for (TelephoneCall observation : train) { + lr.train(observation.getTarget(), observation.asVector()); + } + if (pass % 5 == 0) { + Auc eval = new Auc(0.5); + for (TelephoneCall testCall : test) { + eval.add(testCall.getTarget(), lr.classifyScalar(testCall.asVector())); + } + System.out.printf("%d, %.4f, %.4f\n", pass, lr.currentLearningRate(), eval.auc()); + } + } + } + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr-examples/src/main/java/org/apache/mahout/classifier/sgd/bankmarketing/TelephoneCall.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr-examples/src/main/java/org/apache/mahout/classifier/sgd/bankmarketing/TelephoneCall.java b/community/mahout-mr/mr-examples/src/main/java/org/apache/mahout/classifier/sgd/bankmarketing/TelephoneCall.java new file mode 100644 index 0000000..728ec20 --- /dev/null +++ b/community/mahout-mr/mr-examples/src/main/java/org/apache/mahout/classifier/sgd/bankmarketing/TelephoneCall.java @@ -0,0 +1,104 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd.bankmarketing;
+
+import org.apache.mahout.math.RandomAccessSparseVector;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.vectorizer.encoders.ConstantValueEncoder;
+import org.apache.mahout.vectorizer.encoders.FeatureVectorEncoder;
+import org.apache.mahout.vectorizer.encoders.StaticWordValueEncoder;
+
+import java.util.Iterator;
+import java.util.LinkedHashMap;
+import java.util.Map;
+
+/**
+ * One record of the bank-marketing ("telephone call") example data set, encoded as a
+ * Mahout feature vector. Field names are supplied by the caller, presumably from the
+ * data file's header line (see TelephoneCallParser) -- verify against caller.
+ */
+public class TelephoneCall {
+  // Cardinality of the hashed feature vector.
+  public static final int FEATURES = 100;
+  // Encodes a constant bias ("intercept") term into every vector.
+  private static final ConstantValueEncoder interceptEncoder = new ConstantValueEncoder("intercept");
+  // Hashed encoder shared by all named features.
+  private static final FeatureVectorEncoder featureEncoder = new StaticWordValueEncoder("feature");
+
+  private RandomAccessSparseVector vector;
+
+  // Raw field name -> raw value, kept in input order; getTarget() reads the "y" entry.
+  private Map<String, String> fields = new LinkedHashMap<>();
+
+  /**
+   * Builds the feature vector for one record.
+   *
+   * @param fieldNames names of the fields, in the same order as {@code values}
+   * @param values raw field values; must supply a value for every name in {@code fieldNames}
+   * @throws IllegalArgumentException if an unrecognized field name is encountered
+   */
+  public TelephoneCall(Iterable<String> fieldNames, Iterable<String> values) {
+    vector = new RandomAccessSparseVector(FEATURES);
+    Iterator<String> value = values.iterator();
+    interceptEncoder.addToVector("1", vector);
+    for (String name : fieldNames) {
+      String fieldValue = value.next();
+      fields.put(name, fieldValue);
+
+      switch (name) {
+        case "age": {
+          // Log-transform; NOTE(review): assumes the value parses as a positive number -- confirm input is clean.
+          double v = Double.parseDouble(fieldValue);
+          featureEncoder.addToVector(name, Math.log(v), vector);
+          break;
+        }
+        case "balance": {
+          // Clamp below at -2000 so the shifted log argument stays positive; the -8
+          // offset roughly centers the feature (empirical constants -- verify against data).
+          double v;
+          v = Double.parseDouble(fieldValue);
+          if (v < -2000) {
+            v = -2000;
+          }
+          featureEncoder.addToVector(name, Math.log(v + 2001) - 8, vector);
+          break;
+        }
+        case "duration": {
+          // Shifted log so a zero duration maps to log(1) = 0 before centering by -5.
+          double v;
+          v = Double.parseDouble(fieldValue);
+          featureEncoder.addToVector(name, Math.log(v + 1) - 5, vector);
+          break;
+        }
+        case "pdays": {
+          // The +2 shift keeps the log argument positive (presumably pdays can be -1
+          // meaning "never contacted" -- TODO confirm against the data set).
+          double v;
+          v = Double.parseDouble(fieldValue);
+          featureEncoder.addToVector(name, Math.log(v + 2), vector);
+          break;
+        }
+        // Categorical fields: encoded as "name:value" with unit weight.
+        case "job":
+        case "marital":
+        case "education":
+        case "default":
+        case "housing":
+        case "loan":
+        case "contact":
+        case "campaign":
+        case "previous":
+        case "poutcome":
+          featureEncoder.addToVector(name + ":" + fieldValue, 1, vector);
+          break;
+        case "day":
+        case "month":
+        case "y":
+          // ignore these for vectorizing
+          break;
+        default:
+          throw new IllegalArgumentException(String.format("Bad field name: %s", name));
+      }
+    }
+  }
+
+  /** @return the encoded feature vector for this record */
+  public Vector asVector() {
+    return vector;
+  }
+
+  /** @return the target label: 0 if the "y" field is "no", otherwise 1 */
+  public int getTarget() {
+    return fields.get("y").equals("no") ? 0 : 1;
+  }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr-examples/src/main/java/org/apache/mahout/classifier/sgd/bankmarketing/TelephoneCallParser.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/mr-examples/src/main/java/org/apache/mahout/classifier/sgd/bankmarketing/TelephoneCallParser.java b/community/mahout-mr/mr-examples/src/main/java/org/apache/mahout/classifier/sgd/bankmarketing/TelephoneCallParser.java
new file mode 100644
index 0000000..5ef6490
--- /dev/null
+++ b/community/mahout-mr/mr-examples/src/main/java/org/apache/mahout/classifier/sgd/bankmarketing/TelephoneCallParser.java
@@ -0,0 +1,66 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.classifier.sgd.bankmarketing; + +import com.google.common.base.CharMatcher; +import com.google.common.base.Splitter; +import com.google.common.collect.AbstractIterator; +import com.google.common.io.Resources; + +import java.io.BufferedReader; +import java.io.IOException; +import java.io.InputStreamReader; +import java.util.Iterator; + +/** Parses semi-colon separated data as TelephoneCalls */ +public class TelephoneCallParser implements Iterable<TelephoneCall> { + + private final Splitter onSemi = Splitter.on(";").trimResults(CharMatcher.anyOf("\" ;")); + private String resourceName; + + public TelephoneCallParser(String resourceName) throws IOException { + this.resourceName = resourceName; + } + + @Override + public Iterator<TelephoneCall> iterator() { + try { + return new AbstractIterator<TelephoneCall>() { + BufferedReader input = + new BufferedReader(new InputStreamReader(Resources.getResource(resourceName).openStream())); + Iterable<String> fieldNames = onSemi.split(input.readLine()); + + @Override + protected TelephoneCall computeNext() { + try { + String line = input.readLine(); + if (line == null) { + return endOfData(); + } + + return new TelephoneCall(fieldNames, onSemi.split(line)); + } catch (IOException e) { + throw new RuntimeException("Error reading data", e); + } + } + }; + } catch (IOException e) { + throw new RuntimeException("Error 
reading data", e); + } + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr-examples/src/main/java/org/apache/mahout/clustering/display/ClustersFilter.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr-examples/src/main/java/org/apache/mahout/clustering/display/ClustersFilter.java b/community/mahout-mr/mr-examples/src/main/java/org/apache/mahout/clustering/display/ClustersFilter.java new file mode 100644 index 0000000..a0b845f --- /dev/null +++ b/community/mahout-mr/mr-examples/src/main/java/org/apache/mahout/clustering/display/ClustersFilter.java @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package org.apache.mahout.clustering.display;
+
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.fs.PathFilter;
+
+/** Selects paths whose string form contains "/clusters-" (cluster-iteration output directories). */
+final class ClustersFilter implements PathFilter {
+
+  @Override
+  public boolean accept(Path path) {
+    String pathString = path.toString();
+    return pathString.contains("/clusters-");
+  }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr-examples/src/main/java/org/apache/mahout/clustering/display/DisplayCanopy.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/mr-examples/src/main/java/org/apache/mahout/clustering/display/DisplayCanopy.java b/community/mahout-mr/mr-examples/src/main/java/org/apache/mahout/clustering/display/DisplayCanopy.java
new file mode 100644
index 0000000..50dba99
--- /dev/null
+++ b/community/mahout-mr/mr-examples/src/main/java/org/apache/mahout/clustering/display/DisplayCanopy.java
@@ -0,0 +1,88 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.clustering.display;
+
+import java.awt.BasicStroke;
+import java.awt.Color;
+import java.awt.Graphics;
+import java.awt.Graphics2D;
+import java.util.List;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.Path;
+import org.apache.mahout.clustering.Cluster;
+import org.apache.mahout.clustering.canopy.CanopyDriver;
+import org.apache.mahout.common.HadoopUtil;
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.common.distance.ManhattanDistanceMeasure;
+import org.apache.mahout.math.DenseVector;
+
+/**
+ * Java desktop graphics class that runs canopy clustering and displays the results.
+ * This class generates random data and clusters it.
+ */
+@Deprecated
+public class DisplayCanopy extends DisplayClustering {
+
+  DisplayCanopy() {
+    initialize();
+    this.setTitle("Canopy Clusters (>" + (int) (significance * 100) + "% of population)");
+  }
+
+  @Override
+  public void paint(Graphics g) {
+    // Sample points first, then the cluster overlays on top.
+    plotSampleData((Graphics2D) g);
+    plotClusters((Graphics2D) g);
+  }
+
+  // For every significant canopy: draw the T1/T2 threshold circles in blue, then a
+  // 3-sigma ellipse colored by list index. cx starts at the last index and decrements
+  // per cluster list; the final list gets index 0 and a thicker stroke.
+  protected static void plotClusters(Graphics2D g2) {
+    int cx = CLUSTERS.size() - 1;
+    for (List<Cluster> clusters : CLUSTERS) {
+      for (Cluster cluster : clusters) {
+        if (isSignificant(cluster)) {
+          g2.setStroke(new BasicStroke(1));
+          g2.setColor(Color.BLUE);
+          double[] t1 = {T1, T1};
+          plotEllipse(g2, cluster.getCenter(), new DenseVector(t1));
+          double[] t2 = {T2, T2};
+          plotEllipse(g2, cluster.getCenter(), new DenseVector(t2));
+          g2.setColor(COLORS[Math.min(DisplayClustering.COLORS.length - 1, cx)]);
+          g2.setStroke(new BasicStroke(cx == 0 ? 3 : 1));
+          plotEllipse(g2, cluster.getCenter(), cluster.getRadius().times(3));
+        }
+      }
+      cx--;
+    }
+  }
+
+  // Deletes prior output, generates sample data, runs sequential canopy clustering,
+  // loads the resulting clusters and opens the display frame.
+  public static void main(String[] args) throws Exception {
+    Path samples = new Path("samples");
+    Path output = new Path("output");
+    Configuration conf = new Configuration();
+    HadoopUtil.delete(conf, samples);
+    HadoopUtil.delete(conf, output);
+    RandomUtils.useTestSeed();
+    generateSamples();
+    writeSampleData(samples);
+    CanopyDriver.buildClusters(conf, samples, output, new ManhattanDistanceMeasure(), T1, T2, 0, true);
+    loadClustersWritable(output);
+
+    new DisplayCanopy();
+  }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr-examples/src/main/java/org/apache/mahout/clustering/display/DisplayClustering.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/mr-examples/src/main/java/org/apache/mahout/clustering/display/DisplayClustering.java b/community/mahout-mr/mr-examples/src/main/java/org/apache/mahout/clustering/display/DisplayClustering.java
new file mode 100644
index 0000000..ad85c6a
--- /dev/null
+++ b/community/mahout-mr/mr-examples/src/main/java/org/apache/mahout/clustering/display/DisplayClustering.java
@@ -0,0 +1,374 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.clustering.display;
+
+import java.awt.*;
+import java.awt.event.WindowAdapter;
+import java.awt.event.WindowEvent;
+import java.awt.geom.AffineTransform;
+import java.awt.geom.Ellipse2D;
+import java.awt.geom.Rectangle2D;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileStatus;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.SequenceFile;
+import org.apache.hadoop.io.Text;
+import org.apache.mahout.clustering.AbstractCluster;
+import org.apache.mahout.clustering.Cluster;
+import org.apache.mahout.clustering.UncommonDistributions;
+import org.apache.mahout.clustering.classify.WeightedVectorWritable;
+import org.apache.mahout.clustering.iterator.ClusterWritable;
+import org.apache.mahout.common.Pair;
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.common.iterator.sequencefile.PathFilters;
+import org.apache.mahout.common.iterator.sequencefile.PathType;
+import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirValueIterable;
+import org.apache.mahout.common.iterator.sequencefile.SequenceFileIterable;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * AWT Frame base class shared by the clustering display examples. Holds generated
+ * sample data and loaded clusters in static state and provides the plotting
+ * primitives (rectangles/ellipses drawn in a DS-pixels-per-inch coordinate space).
+ */
+public class DisplayClustering extends Frame {
+
+  private static final Logger log = LoggerFactory.getLogger(DisplayClustering.class);
+
+  protected static final int DS = 72; // default scale = 72 pixels per inch
+
+  protected static final int SIZE = 8; // screen size in inches
+
+  // Parameters (mx, my, sdx, sdy) of each generated sample batch, kept for plotting.
+  private static final Collection<Vector> SAMPLE_PARAMS = new ArrayList<>();
+
+  protected static final List<VectorWritable> SAMPLE_DATA = new ArrayList<>();
+
+  // One List<Cluster> per clustering iteration, as filled by loadClustersWritable().
+  protected static final List<List<Cluster>> CLUSTERS = new ArrayList<>();
+
+  static final Color[] COLORS = { Color.red, Color.orange, Color.yellow, Color.green, Color.blue, Color.magenta,
+    Color.lightGray };
+
+  protected static final double T1 = 3.0;
+
+  protected static final double T2 = 2.8;
+
+  // Minimum fraction of all samples a cluster must contain to be "significant".
+  static double significance = 0.05;
+
+  protected static int res; // screen resolution
+
+  public DisplayClustering() {
+    initialize();
+    this.setTitle("Sample Data");
+  }
+
+  public void initialize() {
+    // Get screen resolution
+    res = Toolkit.getDefaultToolkit().getScreenResolution();
+
+    // Set Frame size in inches
+    this.setSize(SIZE * res, SIZE * res);
+    this.setVisible(true);
+    this.setTitle("Asymmetric Sample Data");
+
+    // Window listener to terminate program.
+    this.addWindowListener(new WindowAdapter() {
+      @Override
+      public void windowClosing(WindowEvent e) {
+        System.exit(0);
+      }
+    });
+  }
+
+  public static void main(String[] args) throws Exception {
+    RandomUtils.useTestSeed();
+    generateSamples();
+    new DisplayClustering();
+  }
+
+  // Override the paint() method
+  @Override
+  public void paint(Graphics g) {
+    Graphics2D g2 = (Graphics2D) g;
+    plotSampleData(g2);
+    plotSampleParameters(g2);
+    plotClusters(g2);
+  }
+
+  // Draws each iteration's clusters as 3-sigma ellipses; cx counts down from the
+  // last index, so the final cluster list gets color index 0 and a thicker stroke.
+  protected static void plotClusters(Graphics2D g2) {
+    int cx = CLUSTERS.size() - 1;
+    for (List<Cluster> clusters : CLUSTERS) {
+      g2.setStroke(new BasicStroke(cx == 0 ? 3 : 1));
+      g2.setColor(COLORS[Math.min(COLORS.length - 1, cx--)]);
+      for (Cluster cluster : clusters) {
+        plotEllipse(g2, cluster.getCenter(), cluster.getRadius().times(3));
+      }
+    }
+  }
+
+  // Red ellipses at each sample generator's mean, scaled by 3x its standard deviations.
+  protected static void plotSampleParameters(Graphics2D g2) {
+    Vector v = new DenseVector(2);
+    Vector dv = new DenseVector(2);
+    g2.setColor(Color.RED);
+    for (Vector param : SAMPLE_PARAMS) {
+      v.set(0, param.get(0));
+      v.set(1, param.get(1));
+      dv.set(0, param.get(2) * 3);
+      dv.set(1, param.get(3) * 3);
+      plotEllipse(g2, v, dv);
+    }
+  }
+
+  protected static void plotSampleData(Graphics2D g2) {
+    double sx = (double) res / DS;
+    g2.setTransform(AffineTransform.getScaleInstance(sx, sx));
+
+    // plot the axes
+    g2.setColor(Color.BLACK);
+    Vector dv = new DenseVector(2).assign(SIZE / 2.0);
+    plotRectangle(g2, new DenseVector(2).assign(2), dv);
+    plotRectangle(g2, new DenseVector(2).assign(-2), dv);
+
+    // plot the sample data
+    g2.setColor(Color.DARK_GRAY);
+    dv.assign(0.03);
+    for (VectorWritable v : SAMPLE_DATA) {
+      plotRectangle(g2, v.get(), dv);
+    }
+  }
+
+  /**
+   * This method plots points and colors them according to their cluster
+   * membership, rather than drawing ellipses.
+   *
+   * As of commit, this method is used only by K-means spectral clustering.
+   * Since the cluster assignments are set within the eigenspace of the data, it
+   * is not inherent that the original data cluster as they would in K-means:
+   * that is, as symmetric gaussian mixtures.
+   *
+   * Since Spectral K-Means uses K-Means to cluster the eigenspace data, the raw
+   * output is not directly usable. Rather, the cluster assignments from the raw
+   * output need to be transferred back to the original data. As such, this
+   * method will read the SequenceFile cluster results of K-means and transfer
+   * the cluster assignments to the original data, coloring them appropriately.
+   *
+   * @param g2
+   * @param data
+   */
+  protected static void plotClusteredSampleData(Graphics2D g2, Path data) {
+    double sx = (double) res / DS;
+    g2.setTransform(AffineTransform.getScaleInstance(sx, sx));
+
+    g2.setColor(Color.BLACK);
+    Vector dv = new DenseVector(2).assign(SIZE / 2.0);
+    plotRectangle(g2, new DenseVector(2).assign(2), dv);
+    plotRectangle(g2, new DenseVector(2).assign(-2), dv);
+
+    // plot the sample data, colored according to the cluster they belong to
+    dv.assign(0.03);
+
+    Path clusteredPointsPath = new Path(data, "clusteredPoints");
+    Path inputPath = new Path(clusteredPointsPath, "part-m-00000");
+    Map<Integer,Color> colors = new HashMap<>();
+    int point = 0;
+    // NOTE(review): assumes clustered points appear in the same order as SAMPLE_DATA -- confirm.
+    for (Pair<IntWritable,WeightedVectorWritable> record : new SequenceFileIterable<IntWritable,WeightedVectorWritable>(
+        inputPath, new Configuration())) {
+      int clusterId = record.getFirst().get();
+      VectorWritable v = SAMPLE_DATA.get(point++);
+      Integer key = clusterId;
+      if (!colors.containsKey(key)) {
+        // First-seen cluster ids get the next unused color.
+        colors.put(key, COLORS[Math.min(COLORS.length - 1, colors.size())]);
+      }
+      plotClusteredRectangle(g2, v.get(), dv, colors.get(key));
+    }
+  }
+
+  /**
+   * Identical to plotRectangle(), but with the option of setting the color of
+   * the rectangle's stroke.
+   *
+   * NOTE: This should probably be refactored with plotRectangle() since most of
+   * the code here is direct copy/paste from that method.
+   *
+   * @param g2
+   *          A Graphics2D context.
+   * @param v
+   *          A vector for the rectangle's center.
+   * @param dv
+   *          A vector for the rectangle's dimensions.
+   * @param color
+   *          The color of the rectangle's stroke.
+   */
+  protected static void plotClusteredRectangle(Graphics2D g2, Vector v, Vector dv, Color color) {
+    double[] flip = {1, -1};
+    Vector v2 = v.times(new DenseVector(flip));
+    v2 = v2.minus(dv.divide(2));
+    int h = SIZE / 2;
+    double x = v2.get(0) + h;
+    double y = v2.get(1) + h;
+
+    g2.setStroke(new BasicStroke(1));
+    g2.setColor(color);
+    g2.draw(new Rectangle2D.Double(x * DS, y * DS, dv.get(0) * DS, dv.get(1) * DS));
+  }
+
+  /**
+   * Draw a rectangle on the graphics context
+   *
+   * @param g2
+   *          a Graphics2D context
+   * @param v
+   *          a Vector of rectangle center
+   * @param dv
+   *          a Vector of rectangle dimensions
+   */
+  protected static void plotRectangle(Graphics2D g2, Vector v, Vector dv) {
+    double[] flip = {1, -1};
+    Vector v2 = v.times(new DenseVector(flip));
+    v2 = v2.minus(dv.divide(2));
+    int h = SIZE / 2;
+    double x = v2.get(0) + h;
+    double y = v2.get(1) + h;
+    g2.draw(new Rectangle2D.Double(x * DS, y * DS, dv.get(0) * DS, dv.get(1) * DS));
+  }
+
+  /**
+   * Draw an ellipse on the graphics context
+   *
+   * @param g2
+   *          a Graphics2D context
+   * @param v
+   *          a Vector of ellipse center
+   * @param dv
+   *          a Vector of ellipse dimensions
+   */
+  protected static void plotEllipse(Graphics2D g2, Vector v, Vector dv) {
+    double[] flip = {1, -1};
+    Vector v2 = v.times(new DenseVector(flip));
+    v2 = v2.minus(dv.divide(2));
+    int h = SIZE / 2;
+    double x = v2.get(0) + h;
+    double y = v2.get(1) + h;
+    g2.draw(new Ellipse2D.Double(x * DS, y * DS, dv.get(0) * DS, dv.get(1) * DS));
+  }
+
+  // Default sample mixture: three symmetric gaussian batches.
+  protected static void generateSamples() {
+    generateSamples(500, 1, 1, 3);
+    generateSamples(300, 1, 0, 0.5);
+    generateSamples(300, 0, 2, 0.1);
+  }
+
+  // Asymmetric (per-axis sd) variant of generateSamples().
+  protected static void generate2dSamples() {
+    generate2dSamples(500, 1, 1, 3, 1);
+    generate2dSamples(300, 1, 0, 0.5, 1);
+    generate2dSamples(300, 0, 2, 0.1, 0.5);
+  }
+
+  /**
+   * Generate random samples and add them to the sampleData
+   *
+   * @param num
+   *          int number of samples to generate
+   * @param mx
+   *          double x-value of the sample mean
+   * @param my
+   *          double y-value of the sample mean
+   * @param sd
+   *          double standard deviation of the samples
+   */
+  protected static void generateSamples(int num, double mx, double my, double sd) {
+    double[] params = {mx, my, sd, sd};
+    SAMPLE_PARAMS.add(new DenseVector(params));
+    log.info("Generating {} samples m=[{}, {}] sd={}", num, mx, my, sd);
+    for (int i = 0; i < num; i++) {
+      SAMPLE_DATA.add(new VectorWritable(new DenseVector(new double[] {UncommonDistributions.rNorm(mx, sd),
+          UncommonDistributions.rNorm(my, sd)})));
+    }
+  }
+
+  // Writes SAMPLE_DATA as a SequenceFile of Text keys ("sample_i") and VectorWritable values.
+  // NOTE(review): uses the deprecated SequenceFile.Writer constructor -- consider SequenceFile.createWriter.
+  protected static void writeSampleData(Path output) throws IOException {
+    Configuration conf = new Configuration();
+    FileSystem fs = FileSystem.get(output.toUri(), conf);
+
+    try (SequenceFile.Writer writer = new SequenceFile.Writer(fs, conf, output, Text.class, VectorWritable.class)) {
+      int i = 0;
+      for (VectorWritable vw : SAMPLE_DATA) {
+        writer.append(new Text("sample_" + i++), vw);
+      }
+    }
+  }
+
+  // Reads all ClusterWritable values from one clusters-* directory, skipping CRC/log files.
+  protected static List<Cluster> readClustersWritable(Path clustersIn) {
+    List<Cluster> clusters = new ArrayList<>();
+    Configuration conf = new Configuration();
+    for (ClusterWritable value : new SequenceFileDirValueIterable<ClusterWritable>(clustersIn, PathType.LIST,
+        PathFilters.logsCRCFilter(), conf)) {
+      Cluster cluster = value.getValue();
+      log.info(
+          "Reading Cluster:{} center:{} numPoints:{} radius:{}",
+          cluster.getId(), AbstractCluster.formatVector(cluster.getCenter(), null),
+          cluster.getNumObservations(), AbstractCluster.formatVector(cluster.getRadius(), null));
+      clusters.add(cluster);
+    }
+    return clusters;
+  }
+
+  // Loads every clusters-* directory under output (see ClustersFilter) into CLUSTERS.
+  protected static void loadClustersWritable(Path output) throws IOException {
+    Configuration conf = new Configuration();
+    FileSystem fs = FileSystem.get(output.toUri(), conf);
+    for (FileStatus s : fs.listStatus(output, new ClustersFilter())) {
+      List<Cluster> clusters = readClustersWritable(s.getPath());
+      CLUSTERS.add(clusters);
+    }
+  }
+
+  /**
+   * Generate random samples and add them to the sampleData
+   *
+   * @param num
+   *          int number of samples to generate
+   * @param mx
+   *          double x-value of the sample mean
+   * @param my
+   *          double y-value of the sample mean
+   * @param sdx
+   *          double x-value standard deviation of the samples
+   * @param sdy
+   *          double y-value standard deviation of the samples
+   */
+  protected static void generate2dSamples(int num, double mx, double my, double sdx, double sdy) {
+    double[] params = {mx, my, sdx, sdy};
+    SAMPLE_PARAMS.add(new DenseVector(params));
+    log.info("Generating {} samples m=[{}, {}] sd=[{}, {}]", num, mx, my, sdx, sdy);
+    for (int i = 0; i < num; i++) {
+      SAMPLE_DATA.add(new VectorWritable(new DenseVector(new double[] {UncommonDistributions.rNorm(mx, sdx),
+          UncommonDistributions.rNorm(my, sdy)})));
+    }
+  }
+
+  // True when the cluster holds more than 'significance' fraction of all samples.
+  protected static boolean isSignificant(Cluster cluster) {
+    return (double) cluster.getNumObservations() / SAMPLE_DATA.size() > significance;
+  }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr-examples/src/main/java/org/apache/mahout/clustering/display/DisplayFuzzyKMeans.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/mr-examples/src/main/java/org/apache/mahout/clustering/display/DisplayFuzzyKMeans.java b/community/mahout-mr/mr-examples/src/main/java/org/apache/mahout/clustering/display/DisplayFuzzyKMeans.java
new file mode 100644
index 0000000..f8ce7c7
--- /dev/null
+++ b/community/mahout-mr/mr-examples/src/main/java/org/apache/mahout/clustering/display/DisplayFuzzyKMeans.java
@@ -0,0 +1,110 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.
You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.clustering.display;
+
+import java.awt.Graphics;
+import java.awt.Graphics2D;
+import java.io.IOException;
+import java.util.Collection;
+import java.util.List;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.Path;
+import org.apache.mahout.clustering.Cluster;
+import org.apache.mahout.clustering.classify.ClusterClassifier;
+import org.apache.mahout.clustering.fuzzykmeans.FuzzyKMeansDriver;
+import org.apache.mahout.clustering.fuzzykmeans.SoftCluster;
+import org.apache.mahout.clustering.iterator.ClusterIterator;
+import org.apache.mahout.clustering.iterator.FuzzyKMeansClusteringPolicy;
+import org.apache.mahout.clustering.kmeans.RandomSeedGenerator;
+import org.apache.mahout.common.HadoopUtil;
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.common.distance.DistanceMeasure;
+import org.apache.mahout.common.distance.ManhattanDistanceMeasure;
+import org.apache.mahout.math.Vector;
+
+import com.google.common.collect.Lists;
+
+/** Runs fuzzy k-means on generated sample data and displays the resulting clusters. */
+public class DisplayFuzzyKMeans extends DisplayClustering {
+
+  DisplayFuzzyKMeans() {
+    initialize();
+    this.setTitle("Fuzzy k-Means Clusters (>" + (int) (significance * 100) + "% of population)");
+  }
+
+  // Override the paint() method
+  @Override
+  public void paint(Graphics g) {
+    plotSampleData((Graphics2D) g);
+    plotClusters((Graphics2D) g);
+  }
+
+  public static void main(String[] args) throws Exception {
+    DistanceMeasure measure = new ManhattanDistanceMeasure();
+
+    Path samples = new Path("samples");
+    Path output = new Path("output");
+    Configuration conf = new Configuration();
+    HadoopUtil.delete(conf, output);
+    HadoopUtil.delete(conf, samples);
+    RandomUtils.useTestSeed();
+    DisplayClustering.generateSamples();
+    writeSampleData(samples);
+    // Toggle: true runs FuzzyKMeansDriver; false runs the classifier-based path below.
+    boolean runClusterer = true;
+    int maxIterations = 10;
+    float threshold = 0.001F;
+    float m = 1.1F;
+    if (runClusterer) {
+      runSequentialFuzzyKClusterer(conf, samples, output, measure, maxIterations, m, threshold);
+    } else {
+      int numClusters = 3;
+      runSequentialFuzzyKClassifier(conf, samples, output, measure, numClusters, maxIterations, m, threshold);
+    }
+    new DisplayFuzzyKMeans();
+  }
+
+  // Seeds soft clusters from the first numClusters samples, then iterates a ClusterClassifier.
+  private static void runSequentialFuzzyKClassifier(Configuration conf, Path samples, Path output,
+      DistanceMeasure measure, int numClusters, int maxIterations, float m, double threshold) throws IOException {
+    Collection<Vector> points = Lists.newArrayList();
+    for (int i = 0; i < numClusters; i++) {
+      points.add(SAMPLE_DATA.get(i).get());
+    }
+    List<Cluster> initialClusters = Lists.newArrayList();
+    int id = 0;
+    for (Vector point : points) {
+      initialClusters.add(new SoftCluster(point, id++, measure));
+    }
+    ClusterClassifier prior = new ClusterClassifier(initialClusters, new FuzzyKMeansClusteringPolicy(m, threshold));
+    Path priorPath = new Path(output, "classifier-0");
+    prior.writeToSeqFiles(priorPath);
+
+    ClusterIterator.iterateSeq(conf, samples, priorPath, output, maxIterations);
+    loadClustersWritable(output);
+  }
+
+  // Random-seeds 3 clusters, then runs the sequential FuzzyKMeansDriver pipeline.
+  private static void runSequentialFuzzyKClusterer(Configuration conf, Path samples, Path output,
+      DistanceMeasure measure, int maxIterations, float m, double threshold) throws IOException,
+      ClassNotFoundException, InterruptedException {
+    Path clustersIn = new Path(output, "random-seeds");
+    RandomSeedGenerator.buildRandom(conf, samples, clustersIn, 3, measure);
+    // NOTE(review): 'threshold' is passed twice (as convergence delta and again later) -- confirm intended.
+    FuzzyKMeansDriver.run(samples, clustersIn, output, threshold, maxIterations, m, true, true, threshold,
+        true);
+
+    loadClustersWritable(output);
+  }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr-examples/src/main/java/org/apache/mahout/clustering/display/DisplayKMeans.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr-examples/src/main/java/org/apache/mahout/clustering/display/DisplayKMeans.java b/community/mahout-mr/mr-examples/src/main/java/org/apache/mahout/clustering/display/DisplayKMeans.java new file mode 100644 index 0000000..336d69e --- /dev/null +++ b/community/mahout-mr/mr-examples/src/main/java/org/apache/mahout/clustering/display/DisplayKMeans.java @@ -0,0 +1,106 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package org.apache.mahout.clustering.display;
+
+import java.awt.Graphics;
+import java.awt.Graphics2D;
+import java.io.IOException;
+import java.util.Collection;
+import java.util.List;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.Path;
+import org.apache.mahout.clustering.Cluster;
+import org.apache.mahout.clustering.classify.ClusterClassifier;
+import org.apache.mahout.clustering.iterator.ClusterIterator;
+import org.apache.mahout.clustering.iterator.KMeansClusteringPolicy;
+import org.apache.mahout.clustering.kmeans.KMeansDriver;
+import org.apache.mahout.clustering.kmeans.RandomSeedGenerator;
+import org.apache.mahout.common.HadoopUtil;
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.common.distance.DistanceMeasure;
+import org.apache.mahout.common.distance.ManhattanDistanceMeasure;
+import org.apache.mahout.math.Vector;
+
+import com.google.common.collect.Lists;
+
+/** Runs k-means on generated sample data and displays the resulting clusters. */
+public class DisplayKMeans extends DisplayClustering {
+
+  DisplayKMeans() {
+    initialize();
+    this.setTitle("k-Means Clusters (>" + (int) (significance * 100) + "% of population)");
+  }
+
+  public static void main(String[] args) throws Exception {
+    DistanceMeasure measure = new ManhattanDistanceMeasure();
+    Path samples = new Path("samples");
+    Path output = new Path("output");
+    Configuration conf = new Configuration();
+    HadoopUtil.delete(conf, samples);
+    HadoopUtil.delete(conf, output);
+
+    RandomUtils.useTestSeed();
+    generateSamples();
+    writeSampleData(samples);
+    // Toggle: true runs KMeansDriver; false runs the classifier-based path below.
+    boolean runClusterer = true;
+    double convergenceDelta = 0.001;
+    int numClusters = 3;
+    int maxIterations = 10;
+    if (runClusterer) {
+      runSequentialKMeansClusterer(conf, samples, output, measure, numClusters, maxIterations, convergenceDelta);
+    } else {
+      runSequentialKMeansClassifier(conf, samples, output, measure, numClusters, maxIterations, convergenceDelta);
+    }
+    new DisplayKMeans();
+  }
+
+  // Seeds clusters from the first numClusters samples, then iterates a ClusterClassifier.
+  private static void runSequentialKMeansClassifier(Configuration conf, Path samples, Path output,
+      DistanceMeasure measure, int numClusters, int maxIterations, double convergenceDelta) throws IOException {
+    Collection<Vector> points = Lists.newArrayList();
+    for (int i = 0; i < numClusters; i++) {
+      points.add(SAMPLE_DATA.get(i).get());
+    }
+    List<Cluster> initialClusters = Lists.newArrayList();
+    int id = 0;
+    for (Vector point : points) {
+      initialClusters.add(new org.apache.mahout.clustering.kmeans.Kluster(point, id++, measure));
+    }
+    ClusterClassifier prior = new ClusterClassifier(initialClusters, new KMeansClusteringPolicy(convergenceDelta));
+    Path priorPath = new Path(output, Cluster.INITIAL_CLUSTERS_DIR);
+    prior.writeToSeqFiles(priorPath);
+
+    ClusterIterator.iterateSeq(conf, samples, priorPath, output, maxIterations);
+    loadClustersWritable(output);
+  }
+
+  // Random-seeds numClusters clusters, then runs the sequential KMeansDriver pipeline.
+  private static void runSequentialKMeansClusterer(Configuration conf, Path samples, Path output,
+      DistanceMeasure measure, int numClusters, int maxIterations, double convergenceDelta)
+      throws IOException, InterruptedException, ClassNotFoundException {
+    Path clustersIn = new Path(output, "random-seeds");
+    RandomSeedGenerator.buildRandom(conf, samples, clustersIn, numClusters, measure);
+    KMeansDriver.run(samples, clustersIn, output, convergenceDelta, maxIterations, true, 0.0, true);
+    loadClustersWritable(output);
+  }
+
+  // Override the paint() method
+  @Override
+  public void paint(Graphics g) {
+    plotSampleData((Graphics2D) g);
+    plotClusters((Graphics2D) g);
+  }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr-examples/src/main/java/org/apache/mahout/clustering/display/DisplaySpectralKMeans.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/mr-examples/src/main/java/org/apache/mahout/clustering/display/DisplaySpectralKMeans.java b/community/mahout-mr/mr-examples/src/main/java/org/apache/mahout/clustering/display/DisplaySpectralKMeans.java
new file mode 100644
index
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.mahout.clustering.display;

import java.awt.Graphics;
import java.awt.Graphics2D;
import java.io.Writer;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Paths;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.mahout.clustering.spectral.kmeans.SpectralKMeansDriver;
import org.apache.mahout.common.HadoopUtil;
import org.apache.mahout.common.RandomUtils;
import org.apache.mahout.common.distance.DistanceMeasure;
import org.apache.mahout.common.distance.ManhattanDistanceMeasure;

/**
 * Visual demo that builds a pairwise-distance affinity matrix over the shared
 * sample data, runs {@link SpectralKMeansDriver} over it and displays the
 * clustered points.
 */
public class DisplaySpectralKMeans extends DisplayClustering {

  protected static final String SAMPLES = "samples";
  protected static final String OUTPUT = "output";
  protected static final String TEMP = "tmp";
  protected static final String AFFINITIES = "affinities";

  DisplaySpectralKMeans() {
    initialize();
    setTitle("Spectral k-Means Clusters (>" + (int) (significance * 100) + "% of population)");
  }

  /**
   * Generates sample data, writes the affinity matrix as CSV, runs spectral
   * k-Means and opens the display window.
   */
  public static void main(String[] args) throws Exception {
    DistanceMeasure measure = new ManhattanDistanceMeasure();
    Path samples = new Path(SAMPLES);
    Path output = new Path(OUTPUT);
    Path tempDir = new Path(TEMP);
    Configuration conf = new Configuration();
    HadoopUtil.delete(conf, samples);
    HadoopUtil.delete(conf, output);

    RandomUtils.useTestSeed();
    DisplayClustering.generateSamples();
    writeSampleData(samples);

    Path affinities = new Path(output, AFFINITIES);
    FileSystem fs = FileSystem.get(output.toUri(), conf);
    if (!fs.exists(output)) {
      fs.mkdirs(output);
    }

    // Write the affinity matrix as "i,j,distance" CSV rows. Use an explicit
    // charset: the previous FileWriter silently used the platform default.
    try (Writer writer = Files.newBufferedWriter(Paths.get(affinities.toString()), StandardCharsets.UTF_8)) {
      for (int i = 0; i < SAMPLE_DATA.size(); i++) {
        for (int j = 0; j < SAMPLE_DATA.size(); j++) {
          writer.write(i + "," + j + ',' + measure.distance(SAMPLE_DATA.get(i).get(),
              SAMPLE_DATA.get(j).get()) + '\n');
        }
      }
    }

    int maxIter = 10;
    double convergenceDelta = 0.001;
    // Reuse the configuration built above rather than constructing a second one,
    // so any settings applied to it (e.g. the deletions) stay consistent.
    SpectralKMeansDriver.run(conf, affinities, output, SAMPLE_DATA.size(), 3, measure,
        convergenceDelta, maxIter, tempDir);
    new DisplaySpectralKMeans();
  }

  /** Paints the clustered sample data produced by the spectral k-Means run. */
  @Override
  public void paint(Graphics g) {
    plotClusteredSampleData((Graphics2D) g, new Path(new Path(OUTPUT), "kmeans_out"));
  }
}
reference clustering implementations over them: + +DisplayClustering - generates 1000 samples from three, symmetric distributions. This is the same + data set that is used by the following clustering programs. It displays the points on a screen + and superimposes the model parameters that were used to generate the points. You can edit the + generateSamples() method to change the sample points used by these programs. + + * DisplayCanopy - uses Canopy clustering + * DisplayKMeans - uses k-Means clustering + * DisplayFuzzyKMeans - uses Fuzzy k-Means clustering + + * NOTE: some of these programs display the sample points and then superimpose all of the clusters + from each iteration. The last iteration's clusters are in bold red and the previous several are + colored (orange, yellow, green, blue, violet) in order after which all earlier clusters are in + light grey. This helps to visualize how the clusters converge upon a solution over multiple + iterations. + * NOTE: by changing the parameter values (k, ALPHA_0, numIterations) and the display SIGNIFICANCE + you can obtain different results. + + + \ No newline at end of file http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr-examples/src/main/java/org/apache/mahout/clustering/streaming/tools/ClusterQualitySummarizer.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr-examples/src/main/java/org/apache/mahout/clustering/streaming/tools/ClusterQualitySummarizer.java b/community/mahout-mr/mr-examples/src/main/java/org/apache/mahout/clustering/streaming/tools/ClusterQualitySummarizer.java new file mode 100644 index 0000000..c29cbc4 --- /dev/null +++ b/community/mahout-mr/mr-examples/src/main/java/org/apache/mahout/clustering/streaming/tools/ClusterQualitySummarizer.java @@ -0,0 +1,279 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.mahout.clustering.streaming.tools;

import java.io.FileOutputStream;
import java.io.IOException;
import java.io.OutputStreamWriter;
import java.io.PrintWriter;
import java.nio.charset.StandardCharsets;
import java.util.List;

import com.google.common.collect.Iterables;
import com.google.common.collect.Lists;
import com.google.common.io.Closeables;
import org.apache.commons.cli2.CommandLine;
import org.apache.commons.cli2.Group;
import org.apache.commons.cli2.Option;
import org.apache.commons.cli2.builder.ArgumentBuilder;
import org.apache.commons.cli2.builder.DefaultOptionBuilder;
import org.apache.commons.cli2.builder.GroupBuilder;
import org.apache.commons.cli2.commandline.Parser;
import org.apache.commons.cli2.util.HelpFormatter;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.mahout.clustering.ClusteringUtils;
import org.apache.mahout.clustering.iterator.ClusterWritable;
import org.apache.mahout.clustering.streaming.mapreduce.CentroidWritable;
import org.apache.mahout.common.AbstractJob;
import org.apache.mahout.common.distance.DistanceMeasure;
import org.apache.mahout.common.distance.SquaredEuclideanDistanceMeasure;
import org.apache.mahout.common.iterator.sequencefile.PathType;
import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirValueIterable;
import org.apache.mahout.math.Centroid;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;
import org.apache.mahout.math.stats.OnlineSummarizer;

/**
 * Command-line tool that summarizes the quality of a clustering: per-cluster
 * distance statistics (mean, standard deviation, quartiles) plus the Dunn and
 * Davies-Bouldin indexes. Results go both to stdout and to a CSV output file.
 * Optionally compares two sets of centroids (e.g. Mahout k-Means vs
 * StreamingKMeans output) over the same train/test data points.
 */
public class ClusterQualitySummarizer extends AbstractJob {
  // Path of the CSV file the summaries are written to.
  private String outputFile;

  // Writer for the CSV output; opened in run(), closed in its finally block.
  private PrintWriter fileOut;

  private String trainFile;
  private String testFile;
  private String centroidFile;
  private String centroidCompareFile;
  // When set, centroid files are read as (IntWritable, ClusterWritable) pairs
  // (Mahout k-Means output) rather than CentroidWritable values.
  private boolean mahoutKMeansFormat;
  private boolean mahoutKMeansFormatCompare;

  private DistanceMeasure distanceMeasure = new SquaredEuclideanDistanceMeasure();

  /** Prints the per-cluster summaries to stdout and to this job's CSV writer. */
  public void printSummaries(List<OnlineSummarizer> summarizers, String type) {
    printSummaries(summarizers, type, fileOut);
  }

  /**
   * Prints one line per cluster with its distance statistics.
   *
   * @param summarizers one summarizer per cluster, holding the distances of the
   *                    cluster's points to its centroid
   * @param type        tag written into the CSV "is.train" column ("train"/"test")
   * @param fileOut     CSV writer; may be null to print to stdout only
   */
  public static void printSummaries(List<OnlineSummarizer> summarizers, String type, PrintWriter fileOut) {
    double maxDistance = 0;
    for (int i = 0; i < summarizers.size(); ++i) {
      OnlineSummarizer summarizer = summarizers.get(i);
      if (summarizer.getCount() > 1) {
        maxDistance = Math.max(maxDistance, summarizer.getMax());
        System.out.printf("Average distance in cluster %d [%d]: %f\n", i, summarizer.getCount(), summarizer.getMean());
        if (fileOut != null) {
          fileOut.printf("%d,%f,%f,%f,%f,%f,%f,%f,%d,%s\n", i, summarizer.getMean(),
              summarizer.getSD(),
              summarizer.getQuartile(0),
              summarizer.getQuartile(1),
              summarizer.getQuartile(2),
              summarizer.getQuartile(3),
              summarizer.getQuartile(4), summarizer.getCount(), type);
        }
      } else {
        // With a single point, quartiles cannot be estimated, so the cluster is
        // reported but skipped in the CSV.
        System.out.printf("Cluster %d has %d data point. Need at least 2 data points in a cluster for"
            + " OnlineSummarizer.\n", i, summarizer.getCount());
      }
    }
    System.out.printf("Num clusters: %d; maxDistance: %f\n", summarizers.size(), maxDistance);
  }

  /**
   * Reads the centroids and data points, writes per-cluster CSV summaries and
   * prints the Dunn and Davies-Bouldin indexes.
   *
   * @return 0 on completion, -1 if argument parsing failed
   */
  public int run(String[] args) throws IOException {
    if (!parseArgs(args)) {
      return -1;
    }

    Configuration conf = new Configuration();
    try {
      // Explicit UTF-8: a bare FileOutputStream-backed PrintWriter would use
      // the platform default charset.
      fileOut = new PrintWriter(new OutputStreamWriter(new FileOutputStream(outputFile), StandardCharsets.UTF_8));
      fileOut.printf("cluster,distance.mean,distance.sd,distance.q0,distance.q1,distance.q2,distance.q3,"
          + "distance.q4,count,is.train\n");

      // Reading in the centroids (both pairs, if they exist).
      List<Centroid> centroids;
      List<Centroid> centroidsCompare = null;
      if (mahoutKMeansFormat) {
        SequenceFileDirValueIterable<ClusterWritable> clusterIterable =
            new SequenceFileDirValueIterable<>(new Path(centroidFile), PathType.GLOB, conf);
        centroids = Lists.newArrayList(IOUtils.getCentroidsFromClusterWritableIterable(clusterIterable));
      } else {
        SequenceFileDirValueIterable<CentroidWritable> centroidIterable =
            new SequenceFileDirValueIterable<>(new Path(centroidFile), PathType.GLOB, conf);
        centroids = Lists.newArrayList(IOUtils.getCentroidsFromCentroidWritableIterable(centroidIterable));
      }

      if (centroidCompareFile != null) {
        if (mahoutKMeansFormatCompare) {
          SequenceFileDirValueIterable<ClusterWritable> clusterCompareIterable =
              new SequenceFileDirValueIterable<>(new Path(centroidCompareFile), PathType.GLOB, conf);
          centroidsCompare = Lists.newArrayList(
              IOUtils.getCentroidsFromClusterWritableIterable(clusterCompareIterable));
        } else {
          SequenceFileDirValueIterable<CentroidWritable> centroidCompareIterable =
              new SequenceFileDirValueIterable<>(new Path(centroidCompareFile), PathType.GLOB, conf);
          centroidsCompare = Lists.newArrayList(
              IOUtils.getCentroidsFromCentroidWritableIterable(centroidCompareIterable));
        }
      }

      // Reading in the "training" set.
      SequenceFileDirValueIterable<VectorWritable> trainIterable =
          new SequenceFileDirValueIterable<>(new Path(trainFile), PathType.GLOB, conf);
      Iterable<Vector> trainDatapoints = IOUtils.getVectorsFromVectorWritableIterable(trainIterable);
      Iterable<Vector> datapoints = trainDatapoints;

      printSummaries(ClusteringUtils.summarizeClusterDistances(trainDatapoints, centroids,
          new SquaredEuclideanDistanceMeasure()), "train");

      // Also adding in the "test" set.
      if (testFile != null) {
        SequenceFileDirValueIterable<VectorWritable> testIterable =
            new SequenceFileDirValueIterable<>(new Path(testFile), PathType.GLOB, conf);
        Iterable<Vector> testDatapoints = IOUtils.getVectorsFromVectorWritableIterable(testIterable);

        printSummaries(ClusteringUtils.summarizeClusterDistances(testDatapoints, centroids,
            new SquaredEuclideanDistanceMeasure()), "test");

        datapoints = Iterables.concat(trainDatapoints, testDatapoints);
      }

      // At this point, all train/test CSVs have been written. We now compute quality metrics
      // over the combined data set.
      List<OnlineSummarizer> summaries =
          ClusteringUtils.summarizeClusterDistances(datapoints, centroids, distanceMeasure);
      List<OnlineSummarizer> compareSummaries = null;
      if (centroidsCompare != null) {
        compareSummaries = ClusteringUtils.summarizeClusterDistances(datapoints, centroidsCompare, distanceMeasure);
      }
      System.out.printf("[Dunn Index] First: %f", ClusteringUtils.dunnIndex(centroids, distanceMeasure, summaries));
      if (compareSummaries != null) {
        System.out.printf(" Second: %f\n", ClusteringUtils.dunnIndex(centroidsCompare, distanceMeasure, compareSummaries));
      } else {
        System.out.printf("\n");
      }
      System.out.printf("[Davies-Bouldin Index] First: %f",
          ClusteringUtils.daviesBouldinIndex(centroids, distanceMeasure, summaries));
      if (compareSummaries != null) {
        System.out.printf(" Second: %f\n",
            ClusteringUtils.daviesBouldinIndex(centroidsCompare, distanceMeasure, compareSummaries));
      } else {
        System.out.printf("\n");
      }
    } catch (IOException e) {
      // Best-effort tool: report and fall through so the writer is still closed.
      System.out.println(e.getMessage());
    } finally {
      Closeables.close(fileOut, false);
    }
    return 0;
  }

  /**
   * Parses the command line into this job's fields.
   *
   * @return true if all required options were supplied, false otherwise
   */
  private boolean parseArgs(String[] args) {
    DefaultOptionBuilder builder = new DefaultOptionBuilder();

    Option help = builder.withLongName("help").withDescription("print this list").create();

    ArgumentBuilder argumentBuilder = new ArgumentBuilder();
    Option inputFileOption = builder.withLongName("input")
        .withShortName("i")
        .withRequired(true)
        .withArgument(argumentBuilder.withName("input").withMaximum(1).create())
        .withDescription("where to get seq files with the vectors (training set)")
        .create();

    Option testInputFileOption = builder.withLongName("testInput")
        .withShortName("itest")
        .withArgument(argumentBuilder.withName("testInput").withMaximum(1).create())
        .withDescription("where to get seq files with the vectors (test set)")
        .create();

    Option centroidsFileOption = builder.withLongName("centroids")
        .withShortName("c")
        .withRequired(true)
        .withArgument(argumentBuilder.withName("centroids").withMaximum(1).create())
        .withDescription("where to get seq files with the centroids (from Mahout KMeans or StreamingKMeansDriver)")
        .create();

    Option centroidsCompareFileOption = builder.withLongName("centroidsCompare")
        .withShortName("cc")
        .withRequired(false)
        .withArgument(argumentBuilder.withName("centroidsCompare").withMaximum(1).create())
        .withDescription("where to get seq files with the second set of centroids (from Mahout KMeans or "
            + "StreamingKMeansDriver)")
        .create();

    Option outputFileOption = builder.withLongName("output")
        .withShortName("o")
        .withRequired(true)
        .withArgument(argumentBuilder.withName("output").withMaximum(1).create())
        .withDescription("where to dump the CSV file with the results")
        .create();

    Option mahoutKMeansFormatOption = builder.withLongName("mahoutkmeansformat")
        .withShortName("mkm")
        .withDescription("if set, read files as (IntWritable, ClusterWritable) pairs")
        .withArgument(argumentBuilder.withName("numpoints").withMaximum(1).create())
        .create();

    Option mahoutKMeansCompareFormatOption = builder.withLongName("mahoutkmeansformatCompare")
        .withShortName("mkmc")
        .withDescription("if set, read files as (IntWritable, ClusterWritable) pairs")
        .withArgument(argumentBuilder.withName("numpoints").withMaximum(1).create())
        .create();

    Group normalArgs = new GroupBuilder()
        .withOption(help)
        .withOption(inputFileOption)
        .withOption(testInputFileOption)
        .withOption(outputFileOption)
        .withOption(centroidsFileOption)
        .withOption(centroidsCompareFileOption)
        .withOption(mahoutKMeansFormatOption)
        .withOption(mahoutKMeansCompareFormatOption)
        .create();

    Parser parser = new Parser();
    parser.setHelpOption(help);
    parser.setHelpTrigger("--help");
    parser.setGroup(normalArgs);
    parser.setHelpFormatter(new HelpFormatter(" ", "", " ", 150));

    CommandLine cmdLine = parser.parseAndHelp(args);
    if (cmdLine == null) {
      return false;
    }

    trainFile = (String) cmdLine.getValue(inputFileOption);
    if (cmdLine.hasOption(testInputFileOption)) {
      testFile = (String) cmdLine.getValue(testInputFileOption);
    }
    centroidFile = (String) cmdLine.getValue(centroidsFileOption);
    if (cmdLine.hasOption(centroidsCompareFileOption)) {
      centroidCompareFile = (String) cmdLine.getValue(centroidsCompareFileOption);
    }
    outputFile = (String) cmdLine.getValue(outputFileOption);
    if (cmdLine.hasOption(mahoutKMeansFormatOption)) {
      mahoutKMeansFormat = true;
    }
    if (cmdLine.hasOption(mahoutKMeansCompareFormatOption)) {
      mahoutKMeansFormatCompare = true;
    }
    return true;
  }

  public static void main(String[] args) throws IOException {
    new ClusterQualitySummarizer().run(args);
  }
}
