http://git-wip-us.apache.org/repos/asf/mahout/blob/99a5358f/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainLogistic.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainLogistic.java b/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainLogistic.java deleted file mode 100644 index f4b8bcb..0000000 --- a/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainLogistic.java +++ /dev/null @@ -1,311 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.mahout.classifier.sgd; - -import com.google.common.io.Resources; -import org.apache.commons.cli2.CommandLine; -import org.apache.commons.cli2.Group; -import org.apache.commons.cli2.Option; -import org.apache.commons.cli2.builder.ArgumentBuilder; -import org.apache.commons.cli2.builder.DefaultOptionBuilder; -import org.apache.commons.cli2.builder.GroupBuilder; -import org.apache.commons.cli2.commandline.Parser; -import org.apache.commons.cli2.util.HelpFormatter; -import org.apache.commons.io.Charsets; -import org.apache.mahout.math.RandomAccessSparseVector; -import org.apache.mahout.math.Vector; - -import java.io.BufferedReader; -import java.io.File; -import java.io.FileInputStream; -import java.io.FileOutputStream; -import java.io.IOException; -import java.io.InputStream; -import java.io.InputStreamReader; -import java.io.OutputStream; -import java.io.OutputStreamWriter; -import java.io.PrintWriter; -import java.util.ArrayList; -import java.util.List; -import java.util.Locale; - -/** - * Train a logistic regression for the examples from Chapter 13 of Mahout in Action - */ -public final class TrainLogistic { - - private static String inputFile; - private static String outputFile; - private static LogisticModelParameters lmp; - private static int passes; - private static boolean scores; - private static OnlineLogisticRegression model; - - private TrainLogistic() { - } - - public static void main(String[] args) throws Exception { - mainToOutput(args, new PrintWriter(new OutputStreamWriter(System.out, Charsets.UTF_8), true)); - } - - static void mainToOutput(String[] args, PrintWriter output) throws Exception { - if (parseArgs(args)) { - double logPEstimate = 0; - int samples = 0; - - CsvRecordFactory csv = lmp.getCsvRecordFactory(); - OnlineLogisticRegression lr = lmp.createRegression(); - for (int pass = 0; pass < passes; pass++) { - try (BufferedReader in = open(inputFile)) { - // read variable names - csv.firstLine(in.readLine()); - - String line = in.readLine(); - while (line != null) { - // for each new line, get target and predictors - Vector input = new RandomAccessSparseVector(lmp.getNumFeatures()); - int targetValue = csv.processLine(line, 
input); - - // check performance while this is still news - double logP = lr.logLikelihood(targetValue, input); - if (!Double.isInfinite(logP)) { - if (samples < 20) { - logPEstimate = (samples * logPEstimate + logP) / (samples + 1); - } else { - logPEstimate = 0.95 * logPEstimate + 0.05 * logP; - } - samples++; - } - double p = lr.classifyScalar(input); - if (scores) { - output.printf(Locale.ENGLISH, "%10d %2d %10.2f %2.4f %10.4f %10.4f%n", - samples, targetValue, lr.currentLearningRate(), p, logP, logPEstimate); - } - - // now update model - lr.train(targetValue, input); - - line = in.readLine(); - } - } - } - - try (OutputStream modelOutput = new FileOutputStream(outputFile)) { - lmp.saveTo(modelOutput); - } - - output.println(lmp.getNumFeatures()); - output.println(lmp.getTargetVariable() + " ~ "); - String sep = ""; - for (String v : csv.getTraceDictionary().keySet()) { - double weight = predictorWeight(lr, 0, csv, v); - if (weight != 0) { - output.printf(Locale.ENGLISH, "%s%.3f*%s", sep, weight, v); - sep = " + "; - } - } - output.printf("%n"); - model = lr; - for (int row = 0; row < lr.getBeta().numRows(); row++) { - for (String key : csv.getTraceDictionary().keySet()) { - double weight = predictorWeight(lr, row, csv, key); - if (weight != 0) { - output.printf(Locale.ENGLISH, "%20s %.5f%n", key, weight); - } - } - for (int column = 0; column < lr.getBeta().numCols(); column++) { - output.printf(Locale.ENGLISH, "%15.9f ", lr.getBeta().get(row, column)); - } - output.println(); - } - } - } - - private static double predictorWeight(OnlineLogisticRegression lr, int row, RecordFactory csv, String predictor) { - double weight = 0; - for (Integer column : csv.getTraceDictionary().get(predictor)) { - weight += lr.getBeta().get(row, column); - } - return weight; - } - - private static boolean parseArgs(String[] args) { - DefaultOptionBuilder builder = new DefaultOptionBuilder(); - - Option help = builder.withLongName("help").withDescription("print this list").create(); - - Option quiet = builder.withLongName("quiet").withDescription("be extra quiet").create(); - Option scores = builder.withLongName("scores").withDescription("output score diagnostics during training").create(); - - ArgumentBuilder argumentBuilder = new ArgumentBuilder(); - Option inputFile = builder.withLongName("input") - .withRequired(true) - .withArgument(argumentBuilder.withName("input").withMaximum(1).create()) - .withDescription("where to get training data") - .create(); - - Option outputFile = builder.withLongName("output") - .withRequired(true) - .withArgument(argumentBuilder.withName("output").withMaximum(1).create()) - .withDescription("where to get training data") - .create(); - - Option predictors = builder.withLongName("predictors") - .withRequired(true) - .withArgument(argumentBuilder.withName("p").create()) - .withDescription("a list of predictor variables") - .create(); - - Option types = builder.withLongName("types") - .withRequired(true) - .withArgument(argumentBuilder.withName("t").create()) - .withDescription("a list of predictor variable types (numeric, word, or text)") - .create(); - - Option target = builder.withLongName("target") - .withRequired(true) - .withArgument(argumentBuilder.withName("target").withMaximum(1).create()) - .withDescription("the name of the target variable") - .create(); - - Option features = builder.withLongName("features") - .withArgument( - argumentBuilder.withName("numFeatures") - .withDefault("1000") - .withMaximum(1).create()) - .withDescription("the number of internal hashed 
features to use") - .create(); - - Option passes = builder.withLongName("passes") - .withArgument( - argumentBuilder.withName("passes") - .withDefault("2") - .withMaximum(1).create()) - .withDescription("the number of times to pass over the input data") - .create(); - - Option lambda = builder.withLongName("lambda") - .withArgument(argumentBuilder.withName("lambda").withDefault("1e-4").withMaximum(1).create()) - .withDescription("the amount of coefficient decay to use") - .create(); - - Option rate = builder.withLongName("rate") - .withArgument(argumentBuilder.withName("learningRate").withDefault("1e-3").withMaximum(1).create()) - .withDescription("the learning rate") - .create(); - - Option noBias = builder.withLongName("noBias") - .withDescription("don't include a bias term") - .create(); - - Option targetCategories = builder.withLongName("categories") - .withRequired(true) - .withArgument(argumentBuilder.withName("number").withMaximum(1).create()) - .withDescription("the number of target categories to be considered") - .create(); - - Group normalArgs = new GroupBuilder() - .withOption(help) - .withOption(quiet) - .withOption(inputFile) - .withOption(outputFile) - .withOption(target) - .withOption(targetCategories) - .withOption(predictors) - .withOption(types) - .withOption(passes) - .withOption(lambda) - .withOption(rate) - .withOption(noBias) - .withOption(features) - .create(); - - Parser parser = new Parser(); - parser.setHelpOption(help); - parser.setHelpTrigger("--help"); - parser.setGroup(normalArgs); - parser.setHelpFormatter(new HelpFormatter(" ", "", " ", 130)); - CommandLine cmdLine = parser.parseAndHelp(args); - - if (cmdLine == null) { - return false; - } - - TrainLogistic.inputFile = getStringArgument(cmdLine, inputFile); - TrainLogistic.outputFile = getStringArgument(cmdLine, outputFile); - - List<String> typeList = new ArrayList<>(); - for (Object x : cmdLine.getValues(types)) { - typeList.add(x.toString()); - } - - List<String> predictorList = new ArrayList<>(); - for (Object x : cmdLine.getValues(predictors)) { - predictorList.add(x.toString()); - } - - lmp = new LogisticModelParameters(); - lmp.setTargetVariable(getStringArgument(cmdLine, target)); - lmp.setMaxTargetCategories(getIntegerArgument(cmdLine, targetCategories)); - lmp.setNumFeatures(getIntegerArgument(cmdLine, features)); - lmp.setUseBias(!getBooleanArgument(cmdLine, noBias)); - lmp.setTypeMap(predictorList, typeList); - - lmp.setLambda(getDoubleArgument(cmdLine, lambda)); - lmp.setLearningRate(getDoubleArgument(cmdLine, rate)); - - TrainLogistic.scores = getBooleanArgument(cmdLine, scores); - TrainLogistic.passes = getIntegerArgument(cmdLine, passes); - - return true; - } - - private static String getStringArgument(CommandLine cmdLine, Option inputFile) { - return (String) cmdLine.getValue(inputFile); - } - - private static boolean getBooleanArgument(CommandLine cmdLine, Option option) { - return cmdLine.hasOption(option); - } - - private static int getIntegerArgument(CommandLine cmdLine, Option features) { - return Integer.parseInt((String) cmdLine.getValue(features)); - } - - private static double getDoubleArgument(CommandLine cmdLine, Option op) { - return Double.parseDouble((String) cmdLine.getValue(op)); - } - - public static OnlineLogisticRegression getModel() { - return model; - } - - public static LogisticModelParameters getParameters() { - return lmp; - } - - static BufferedReader open(String inputFile) throws IOException { - InputStream in; - try { - in = 
Resources.getResource(inputFile).openStream(); - } catch (IllegalArgumentException e) { - in = new FileInputStream(new File(inputFile)); - } - return new BufferedReader(new InputStreamReader(in, Charsets.UTF_8)); - } -}
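The deleted TrainLogistic above is largely command-line and CSV plumbing around Mahout's OnlineLogisticRegression learner. As an illustrative sketch only — not part of this commit; the class name, toy data, and feature layout below are invented — the core train-and-score loop it drove reduces to roughly the following, using the same Mahout classes the deleted file imports:

// Illustrative sketch, not from commit 99a5358f: a minimal SGD logistic
// regression train/score loop with toy data, stripped of the CLI/CSV layer.
import org.apache.mahout.classifier.sgd.L1;
import org.apache.mahout.classifier.sgd.OnlineLogisticRegression;
import org.apache.mahout.math.RandomAccessSparseVector;
import org.apache.mahout.math.Vector;

public class MinimalSgdSketch {
  public static void main(String[] args) {
    int numFeatures = 20;
    int numCategories = 2;

    // Two-class logistic regression with an L1 prior, as in the deleted examples.
    OnlineLogisticRegression lr =
        new OnlineLogisticRegression(numCategories, numFeatures, new L1())
            .learningRate(1e-3)
            .lambda(1e-4);

    // Toy training data: target is 1 iff feature 0 is "large".
    double[][] data = {{0.1, 1.0}, {0.9, 1.0}, {0.2, 1.0}, {0.8, 1.0}};
    int[] targets = {0, 1, 0, 1};

    for (int pass = 0; pass < 10; pass++) {
      for (int i = 0; i < data.length; i++) {
        Vector v = new RandomAccessSparseVector(numFeatures);
        v.set(0, data[i][0]);
        v.set(1, data[i][1]);          // constant feature standing in for a bias term
        lr.train(targets[i], v);       // one SGD update per record
      }
    }

    // Score a probe record; classifyScalar returns P(class = 1) for 2-class models.
    Vector probe = new RandomAccessSparseVector(numFeatures);
    probe.set(0, 0.95);
    probe.set(1, 1.0);
    System.out.println("P(class=1) = " + lr.classifyScalar(probe));
  }
}

In the deleted example the same calls were fed from CsvRecordFactory.processLine() and repeated over several passes of the input file, with the learned model saved via LogisticModelParameters.saveTo().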
http://git-wip-us.apache.org/repos/asf/mahout/blob/99a5358f/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainNewsGroups.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainNewsGroups.java b/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainNewsGroups.java deleted file mode 100644 index 632b32c..0000000 --- a/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainNewsGroups.java +++ /dev/null @@ -1,154 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.mahout.classifier.sgd; - -import com.google.common.collect.HashMultiset; -import com.google.common.collect.Multiset; -import com.google.common.collect.Ordering; -import org.apache.mahout.classifier.NewsgroupHelper; -import org.apache.mahout.ep.State; -import org.apache.mahout.math.Vector; -import org.apache.mahout.vectorizer.encoders.Dictionary; - -import java.io.File; -import java.io.IOException; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collections; -import java.util.List; - -/** - * Reads and trains an adaptive logistic regression model on the 20 newsgroups data. - * The first command line argument gives the path of the directory holding the training - * data. The optional second argument, leakType, defines which classes of features to use. - * Importantly, leakType controls whether a synthetic date is injected into the data as - * a target leak and if so, how. - * <p/> - * The value of leakType % 3 determines whether the target leak is injected according to - * the following table: - * <p/> - * <table> - * <tr><td valign='top'>0</td><td>No leak injected</td></tr> - * <tr><td valign='top'>1</td><td>Synthetic date injected in MMM-yyyy format. This will be a single token and - * is a perfect target leak since each newsgroup is given a different month</td></tr> - * <tr><td valign='top'>2</td><td>Synthetic date injected in dd-MMM-yyyy HH:mm:ss format. The day varies - * and thus there are more leak symbols that need to be learned. Ultimately this is just - * as big a leak as case 1.</td></tr> - * </table> - * <p/> - * Leaktype also determines what other text will be indexed. If leakType is greater - * than or equal to 6, then neither headers nor text body will be used for features and the leak is the only - * source of data. If leakType is greater than or equal to 3, then subject words will be used as features. - * If leakType is less than 3, then both subject and body text will be used as features. - * <p/> - * A leakType of 0 gives no leak and all textual features. 
- * <p/> - * See the following table for a summary of commonly used values for leakType - * <p/> - * <table> - * <tr><td><b>leakType</b></td><td><b>Leak?</b></td><td><b>Subject?</b></td><td><b>Body?</b></td></tr> - * <tr><td colspan=4><hr></td></tr> - * <tr><td>0</td><td>no</td><td>yes</td><td>yes</td></tr> - * <tr><td>1</td><td>mmm-yyyy</td><td>yes</td><td>yes</td></tr> - * <tr><td>2</td><td>dd-mmm-yyyy</td><td>yes</td><td>yes</td></tr> - * <tr><td colspan=4><hr></td></tr> - * <tr><td>3</td><td>no</td><td>yes</td><td>no</td></tr> - * <tr><td>4</td><td>mmm-yyyy</td><td>yes</td><td>no</td></tr> - * <tr><td>5</td><td>dd-mmm-yyyy</td><td>yes</td><td>no</td></tr> - * <tr><td colspan=4><hr></td></tr> - * <tr><td>6</td><td>no</td><td>no</td><td>no</td></tr> - * <tr><td>7</td><td>mmm-yyyy</td><td>no</td><td>no</td></tr> - * <tr><td>8</td><td>dd-mmm-yyyy</td><td>no</td><td>no</td></tr> - * <tr><td colspan=4><hr></td></tr> - * </table> - */ -public final class TrainNewsGroups { - - private TrainNewsGroups() { - } - - public static void main(String[] args) throws IOException { - File base = new File(args[0]); - - Multiset<String> overallCounts = HashMultiset.create(); - - int leakType = 0; - if (args.length > 1) { - leakType = Integer.parseInt(args[1]); - } - - Dictionary newsGroups = new Dictionary(); - - NewsgroupHelper helper = new NewsgroupHelper(); - helper.getEncoder().setProbes(2); - AdaptiveLogisticRegression learningAlgorithm = - new AdaptiveLogisticRegression(20, NewsgroupHelper.FEATURES, new L1()); - learningAlgorithm.setInterval(800); - learningAlgorithm.setAveragingWindow(500); - - List<File> files = new ArrayList<>(); - for (File newsgroup : base.listFiles()) { - if (newsgroup.isDirectory()) { - newsGroups.intern(newsgroup.getName()); - files.addAll(Arrays.asList(newsgroup.listFiles())); - } - } - Collections.shuffle(files); - System.out.println(files.size() + " training files"); - SGDInfo info = new SGDInfo(); - - int k = 0; - - for (File file : files) { - String ng = file.getParentFile().getName(); - int actual = newsGroups.intern(ng); - - Vector v = helper.encodeFeatureVector(file, actual, leakType, overallCounts); - learningAlgorithm.train(actual, v); - - k++; - State<AdaptiveLogisticRegression.Wrapper, CrossFoldLearner> best = learningAlgorithm.getBest(); - - SGDHelper.analyzeState(info, leakType, k, best); - } - learningAlgorithm.close(); - SGDHelper.dissect(leakType, newsGroups, learningAlgorithm, files, overallCounts); - System.out.println("exiting main"); - - File modelFile = new File(System.getProperty("java.io.tmpdir"), "news-group.model"); - ModelSerializer.writeBinary(modelFile.getAbsolutePath(), - learningAlgorithm.getBest().getPayload().getLearner().getModels().get(0)); - - List<Integer> counts = new ArrayList<>(); - System.out.println("Word counts"); - for (String count : overallCounts.elementSet()) { - counts.add(overallCounts.count(count)); - } - Collections.sort(counts, Ordering.natural().reverse()); - k = 0; - for (Integer count : counts) { - System.out.println(k + "\t" + count); - k++; - if (k > 1000) { - break; - } - } - } - - -} http://git-wip-us.apache.org/repos/asf/mahout/blob/99a5358f/examples/src/main/java/org/apache/mahout/classifier/sgd/ValidateAdaptiveLogistic.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/mahout/classifier/sgd/ValidateAdaptiveLogistic.java b/examples/src/main/java/org/apache/mahout/classifier/sgd/ValidateAdaptiveLogistic.java deleted file mode 100644 index 
7a74289..0000000 --- a/examples/src/main/java/org/apache/mahout/classifier/sgd/ValidateAdaptiveLogistic.java +++ /dev/null @@ -1,218 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.mahout.classifier.sgd; - -import java.io.BufferedReader; -import java.io.File; -import java.io.IOException; -import java.io.OutputStreamWriter; -import java.io.PrintWriter; -import java.util.Locale; - -import org.apache.commons.cli2.CommandLine; -import org.apache.commons.cli2.Group; -import org.apache.commons.cli2.Option; -import org.apache.commons.cli2.builder.ArgumentBuilder; -import org.apache.commons.cli2.builder.DefaultOptionBuilder; -import org.apache.commons.cli2.builder.GroupBuilder; -import org.apache.commons.cli2.commandline.Parser; -import org.apache.commons.cli2.util.HelpFormatter; -import org.apache.commons.io.Charsets; -import org.apache.mahout.classifier.ConfusionMatrix; -import org.apache.mahout.classifier.evaluation.Auc; -import org.apache.mahout.classifier.sgd.AdaptiveLogisticRegression.Wrapper; -import org.apache.mahout.ep.State; -import org.apache.mahout.math.Matrix; -import org.apache.mahout.math.SequentialAccessSparseVector; -import org.apache.mahout.math.Vector; -import org.apache.mahout.math.stats.OnlineSummarizer; - -/* - * Auc and averageLikelihood are always shown if possible, if the number of target value is more than 2, - * then Auc and entropy matirx are not shown regardless the value of showAuc and showEntropy - * the user passes, because the current implementation does not support them on two value targets. 
- * */ -public final class ValidateAdaptiveLogistic { - - private static String inputFile; - private static String modelFile; - private static String defaultCategory; - private static boolean showAuc; - private static boolean showScores; - private static boolean showConfusion; - - private ValidateAdaptiveLogistic() { - } - - public static void main(String[] args) throws IOException { - mainToOutput(args, new PrintWriter(new OutputStreamWriter(System.out, Charsets.UTF_8), true)); - } - - static void mainToOutput(String[] args, PrintWriter output) throws IOException { - if (parseArgs(args)) { - if (!showAuc && !showConfusion && !showScores) { - showAuc = true; - showConfusion = true; - } - - Auc collector = null; - AdaptiveLogisticModelParameters lmp = AdaptiveLogisticModelParameters - .loadFromFile(new File(modelFile)); - CsvRecordFactory csv = lmp.getCsvRecordFactory(); - AdaptiveLogisticRegression lr = lmp.createAdaptiveLogisticRegression(); - - if (lmp.getTargetCategories().size() <= 2) { - collector = new Auc(); - } - - OnlineSummarizer slh = new OnlineSummarizer(); - ConfusionMatrix cm = new ConfusionMatrix(lmp.getTargetCategories(), defaultCategory); - - State<Wrapper, CrossFoldLearner> best = lr.getBest(); - if (best == null) { - output.println("AdaptiveLogisticRegression has not be trained probably."); - return; - } - CrossFoldLearner learner = best.getPayload().getLearner(); - - BufferedReader in = TrainLogistic.open(inputFile); - String line = in.readLine(); - csv.firstLine(line); - line = in.readLine(); - if (showScores) { - output.println("\"target\", \"model-output\", \"log-likelihood\", \"average-likelihood\""); - } - while (line != null) { - Vector v = new SequentialAccessSparseVector(lmp.getNumFeatures()); - //TODO: How to avoid extra target values not shown in the training process. 
- int target = csv.processLine(line, v); - double likelihood = learner.logLikelihood(target, v); - double score = learner.classifyFull(v).maxValue(); - - slh.add(likelihood); - cm.addInstance(csv.getTargetString(line), csv.getTargetLabel(target)); - - if (showScores) { - output.printf(Locale.ENGLISH, "%8d, %.12f, %.13f, %.13f%n", target, - score, learner.logLikelihood(target, v), slh.getMean()); - } - if (collector != null) { - collector.add(target, score); - } - line = in.readLine(); - } - - output.printf(Locale.ENGLISH,"\nLog-likelihood:"); - output.printf(Locale.ENGLISH, "Min=%.2f, Max=%.2f, Mean=%.2f, Median=%.2f%n", - slh.getMin(), slh.getMax(), slh.getMean(), slh.getMedian()); - - if (collector != null) { - output.printf(Locale.ENGLISH, "%nAUC = %.2f%n", collector.auc()); - } - - if (showConfusion) { - output.printf(Locale.ENGLISH, "%n%s%n%n", cm.toString()); - - if (collector != null) { - Matrix m = collector.entropy(); - output.printf(Locale.ENGLISH, - "Entropy Matrix: [[%.1f, %.1f], [%.1f, %.1f]]%n", m.get(0, 0), - m.get(1, 0), m.get(0, 1), m.get(1, 1)); - } - } - - } - } - - private static boolean parseArgs(String[] args) { - DefaultOptionBuilder builder = new DefaultOptionBuilder(); - - Option help = builder.withLongName("help") - .withDescription("print this list").create(); - - Option quiet = builder.withLongName("quiet") - .withDescription("be extra quiet").create(); - - Option auc = builder.withLongName("auc").withDescription("print AUC") - .create(); - Option confusion = builder.withLongName("confusion") - .withDescription("print confusion matrix").create(); - - Option scores = builder.withLongName("scores") - .withDescription("print scores").create(); - - ArgumentBuilder argumentBuilder = new ArgumentBuilder(); - Option inputFileOption = builder - .withLongName("input") - .withRequired(true) - .withArgument( - argumentBuilder.withName("input").withMaximum(1) - .create()) - .withDescription("where to get validate data").create(); - - Option modelFileOption = builder - .withLongName("model") - .withRequired(true) - .withArgument( - argumentBuilder.withName("model").withMaximum(1) - .create()) - .withDescription("where to get the trained model").create(); - - Option defaultCagetoryOption = builder - .withLongName("defaultCategory") - .withRequired(false) - .withArgument( - argumentBuilder.withName("defaultCategory").withMaximum(1).withDefault("unknown") - .create()) - .withDescription("the default category value to use").create(); - - Group normalArgs = new GroupBuilder().withOption(help) - .withOption(quiet).withOption(auc).withOption(scores) - .withOption(confusion).withOption(inputFileOption) - .withOption(modelFileOption).withOption(defaultCagetoryOption).create(); - - Parser parser = new Parser(); - parser.setHelpOption(help); - parser.setHelpTrigger("--help"); - parser.setGroup(normalArgs); - parser.setHelpFormatter(new HelpFormatter(" ", "", " ", 130)); - CommandLine cmdLine = parser.parseAndHelp(args); - - if (cmdLine == null) { - return false; - } - - inputFile = getStringArgument(cmdLine, inputFileOption); - modelFile = getStringArgument(cmdLine, modelFileOption); - defaultCategory = getStringArgument(cmdLine, defaultCagetoryOption); - showAuc = getBooleanArgument(cmdLine, auc); - showScores = getBooleanArgument(cmdLine, scores); - showConfusion = getBooleanArgument(cmdLine, confusion); - - return true; - } - - private static boolean getBooleanArgument(CommandLine cmdLine, Option option) { - return cmdLine.hasOption(option); - } - - private static String 
getStringArgument(CommandLine cmdLine, Option inputFile) { - return (String) cmdLine.getValue(inputFile); - } - -} http://git-wip-us.apache.org/repos/asf/mahout/blob/99a5358f/examples/src/main/java/org/apache/mahout/classifier/sgd/bankmarketing/BankMarketingClassificationMain.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/mahout/classifier/sgd/bankmarketing/BankMarketingClassificationMain.java b/examples/src/main/java/org/apache/mahout/classifier/sgd/bankmarketing/BankMarketingClassificationMain.java deleted file mode 100644 index ab3c861..0000000 --- a/examples/src/main/java/org/apache/mahout/classifier/sgd/bankmarketing/BankMarketingClassificationMain.java +++ /dev/null @@ -1,70 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.mahout.classifier.sgd.bankmarketing; - -import com.google.common.collect.Lists; -import org.apache.mahout.classifier.evaluation.Auc; -import org.apache.mahout.classifier.sgd.L1; -import org.apache.mahout.classifier.sgd.OnlineLogisticRegression; - -import java.util.Collections; -import java.util.List; - -/** - * Uses the SGD classifier on the 'Bank marketing' dataset from UCI. - * - * See http://archive.ics.uci.edu/ml/datasets/Bank+Marketing - * - * Learn when people accept or reject an offer from the bank via telephone based on income, age, education and more. 
- */ -public class BankMarketingClassificationMain { - - public static final int NUM_CATEGORIES = 2; - - public static void main(String[] args) throws Exception { - List<TelephoneCall> calls = Lists.newArrayList(new TelephoneCallParser("bank-full.csv")); - - double heldOutPercentage = 0.10; - - for (int run = 0; run < 20; run++) { - Collections.shuffle(calls); - int cutoff = (int) (heldOutPercentage * calls.size()); - List<TelephoneCall> test = calls.subList(0, cutoff); - List<TelephoneCall> train = calls.subList(cutoff, calls.size()); - - OnlineLogisticRegression lr = new OnlineLogisticRegression(NUM_CATEGORIES, TelephoneCall.FEATURES, new L1()) - .learningRate(1) - .alpha(1) - .lambda(0.000001) - .stepOffset(10000) - .decayExponent(0.2); - for (int pass = 0; pass < 20; pass++) { - for (TelephoneCall observation : train) { - lr.train(observation.getTarget(), observation.asVector()); - } - if (pass % 5 == 0) { - Auc eval = new Auc(0.5); - for (TelephoneCall testCall : test) { - eval.add(testCall.getTarget(), lr.classifyScalar(testCall.asVector())); - } - System.out.printf("%d, %.4f, %.4f\n", pass, lr.currentLearningRate(), eval.auc()); - } - } - } - } -} http://git-wip-us.apache.org/repos/asf/mahout/blob/99a5358f/examples/src/main/java/org/apache/mahout/classifier/sgd/bankmarketing/TelephoneCall.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/mahout/classifier/sgd/bankmarketing/TelephoneCall.java b/examples/src/main/java/org/apache/mahout/classifier/sgd/bankmarketing/TelephoneCall.java deleted file mode 100644 index 728ec20..0000000 --- a/examples/src/main/java/org/apache/mahout/classifier/sgd/bankmarketing/TelephoneCall.java +++ /dev/null @@ -1,104 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.mahout.classifier.sgd.bankmarketing; - -import org.apache.mahout.math.RandomAccessSparseVector; -import org.apache.mahout.math.Vector; -import org.apache.mahout.vectorizer.encoders.ConstantValueEncoder; -import org.apache.mahout.vectorizer.encoders.FeatureVectorEncoder; -import org.apache.mahout.vectorizer.encoders.StaticWordValueEncoder; - -import java.util.Iterator; -import java.util.LinkedHashMap; -import java.util.Map; - -public class TelephoneCall { - public static final int FEATURES = 100; - private static final ConstantValueEncoder interceptEncoder = new ConstantValueEncoder("intercept"); - private static final FeatureVectorEncoder featureEncoder = new StaticWordValueEncoder("feature"); - - private RandomAccessSparseVector vector; - - private Map<String, String> fields = new LinkedHashMap<>(); - - public TelephoneCall(Iterable<String> fieldNames, Iterable<String> values) { - vector = new RandomAccessSparseVector(FEATURES); - Iterator<String> value = values.iterator(); - interceptEncoder.addToVector("1", vector); - for (String name : fieldNames) { - String fieldValue = value.next(); - fields.put(name, fieldValue); - - switch (name) { - case "age": { - double v = Double.parseDouble(fieldValue); - featureEncoder.addToVector(name, Math.log(v), vector); - break; - } - case "balance": { - double v; - v = Double.parseDouble(fieldValue); - if (v < -2000) { - v = -2000; - } - featureEncoder.addToVector(name, Math.log(v + 2001) - 8, vector); - break; - } - case "duration": { - double v; - v = Double.parseDouble(fieldValue); - featureEncoder.addToVector(name, Math.log(v + 1) - 5, vector); - break; - } - case "pdays": { - double v; - v = Double.parseDouble(fieldValue); - featureEncoder.addToVector(name, Math.log(v + 2), vector); - break; - } - case "job": - case "marital": - case "education": - case "default": - case "housing": - case "loan": - case "contact": - case "campaign": - case "previous": - case "poutcome": - featureEncoder.addToVector(name + ":" + fieldValue, 1, vector); - break; - case "day": - case "month": - case "y": - // ignore these for vectorizing - break; - default: - throw new IllegalArgumentException(String.format("Bad field name: %s", name)); - } - } - } - - public Vector asVector() { - return vector; - } - - public int getTarget() { - return fields.get("y").equals("no") ? 0 : 1; - } -} http://git-wip-us.apache.org/repos/asf/mahout/blob/99a5358f/examples/src/main/java/org/apache/mahout/classifier/sgd/bankmarketing/TelephoneCallParser.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/mahout/classifier/sgd/bankmarketing/TelephoneCallParser.java b/examples/src/main/java/org/apache/mahout/classifier/sgd/bankmarketing/TelephoneCallParser.java deleted file mode 100644 index 5ef6490..0000000 --- a/examples/src/main/java/org/apache/mahout/classifier/sgd/bankmarketing/TelephoneCallParser.java +++ /dev/null @@ -1,66 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.mahout.classifier.sgd.bankmarketing; - -import com.google.common.base.CharMatcher; -import com.google.common.base.Splitter; -import com.google.common.collect.AbstractIterator; -import com.google.common.io.Resources; - -import java.io.BufferedReader; -import java.io.IOException; -import java.io.InputStreamReader; -import java.util.Iterator; - -/** Parses semi-colon separated data as TelephoneCalls */ -public class TelephoneCallParser implements Iterable<TelephoneCall> { - - private final Splitter onSemi = Splitter.on(";").trimResults(CharMatcher.anyOf("\" ;")); - private String resourceName; - - public TelephoneCallParser(String resourceName) throws IOException { - this.resourceName = resourceName; - } - - @Override - public Iterator<TelephoneCall> iterator() { - try { - return new AbstractIterator<TelephoneCall>() { - BufferedReader input = - new BufferedReader(new InputStreamReader(Resources.getResource(resourceName).openStream())); - Iterable<String> fieldNames = onSemi.split(input.readLine()); - - @Override - protected TelephoneCall computeNext() { - try { - String line = input.readLine(); - if (line == null) { - return endOfData(); - } - - return new TelephoneCall(fieldNames, onSemi.split(line)); - } catch (IOException e) { - throw new RuntimeException("Error reading data", e); - } - } - }; - } catch (IOException e) { - throw new RuntimeException("Error reading data", e); - } - } -} http://git-wip-us.apache.org/repos/asf/mahout/blob/99a5358f/examples/src/main/java/org/apache/mahout/clustering/display/ClustersFilter.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/mahout/clustering/display/ClustersFilter.java b/examples/src/main/java/org/apache/mahout/clustering/display/ClustersFilter.java deleted file mode 100644 index a0b845f..0000000 --- a/examples/src/main/java/org/apache/mahout/clustering/display/ClustersFilter.java +++ /dev/null @@ -1,31 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.mahout.clustering.display; - -import org.apache.hadoop.fs.Path; -import org.apache.hadoop.fs.PathFilter; - -final class ClustersFilter implements PathFilter { - - @Override - public boolean accept(Path path) { - String pathString = path.toString(); - return pathString.contains("/clusters-"); - } - -} http://git-wip-us.apache.org/repos/asf/mahout/blob/99a5358f/examples/src/main/java/org/apache/mahout/clustering/display/DisplayCanopy.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/mahout/clustering/display/DisplayCanopy.java b/examples/src/main/java/org/apache/mahout/clustering/display/DisplayCanopy.java deleted file mode 100644 index 50dba99..0000000 --- a/examples/src/main/java/org/apache/mahout/clustering/display/DisplayCanopy.java +++ /dev/null @@ -1,88 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.mahout.clustering.display; - -import java.awt.BasicStroke; -import java.awt.Color; -import java.awt.Graphics; -import java.awt.Graphics2D; -import java.util.List; - -import org.apache.hadoop.conf.Configuration; -import org.apache.hadoop.fs.Path; -import org.apache.mahout.clustering.Cluster; -import org.apache.mahout.clustering.canopy.CanopyDriver; -import org.apache.mahout.common.HadoopUtil; -import org.apache.mahout.common.RandomUtils; -import org.apache.mahout.common.distance.ManhattanDistanceMeasure; -import org.apache.mahout.math.DenseVector; - -/** - * Java desktop graphics class that runs canopy clustering and displays the results. - * This class generates random data and clusters it. - */ -@Deprecated -public class DisplayCanopy extends DisplayClustering { - - DisplayCanopy() { - initialize(); - this.setTitle("Canopy Clusters (>" + (int) (significance * 100) + "% of population)"); - } - - @Override - public void paint(Graphics g) { - plotSampleData((Graphics2D) g); - plotClusters((Graphics2D) g); - } - - protected static void plotClusters(Graphics2D g2) { - int cx = CLUSTERS.size() - 1; - for (List<Cluster> clusters : CLUSTERS) { - for (Cluster cluster : clusters) { - if (isSignificant(cluster)) { - g2.setStroke(new BasicStroke(1)); - g2.setColor(Color.BLUE); - double[] t1 = {T1, T1}; - plotEllipse(g2, cluster.getCenter(), new DenseVector(t1)); - double[] t2 = {T2, T2}; - plotEllipse(g2, cluster.getCenter(), new DenseVector(t2)); - g2.setColor(COLORS[Math.min(DisplayClustering.COLORS.length - 1, cx)]); - g2.setStroke(new BasicStroke(cx == 0 ? 
3 : 1)); - plotEllipse(g2, cluster.getCenter(), cluster.getRadius().times(3)); - } - } - cx--; - } - } - - public static void main(String[] args) throws Exception { - Path samples = new Path("samples"); - Path output = new Path("output"); - Configuration conf = new Configuration(); - HadoopUtil.delete(conf, samples); - HadoopUtil.delete(conf, output); - RandomUtils.useTestSeed(); - generateSamples(); - writeSampleData(samples); - CanopyDriver.buildClusters(conf, samples, output, new ManhattanDistanceMeasure(), T1, T2, 0, true); - loadClustersWritable(output); - - new DisplayCanopy(); - } - -} http://git-wip-us.apache.org/repos/asf/mahout/blob/99a5358f/examples/src/main/java/org/apache/mahout/clustering/display/DisplayClustering.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/mahout/clustering/display/DisplayClustering.java b/examples/src/main/java/org/apache/mahout/clustering/display/DisplayClustering.java deleted file mode 100644 index ad85c6a..0000000 --- a/examples/src/main/java/org/apache/mahout/clustering/display/DisplayClustering.java +++ /dev/null @@ -1,374 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.mahout.clustering.display; - -import java.awt.*; -import java.awt.event.WindowAdapter; -import java.awt.event.WindowEvent; -import java.awt.geom.AffineTransform; -import java.awt.geom.Ellipse2D; -import java.awt.geom.Rectangle2D; -import java.io.IOException; -import java.util.ArrayList; -import java.util.Collection; -import java.util.HashMap; -import java.util.List; -import java.util.Map; - -import org.apache.hadoop.conf.Configuration; -import org.apache.hadoop.fs.FileStatus; -import org.apache.hadoop.fs.FileSystem; -import org.apache.hadoop.fs.Path; -import org.apache.hadoop.io.IntWritable; -import org.apache.hadoop.io.SequenceFile; -import org.apache.hadoop.io.Text; -import org.apache.mahout.clustering.AbstractCluster; -import org.apache.mahout.clustering.Cluster; -import org.apache.mahout.clustering.UncommonDistributions; -import org.apache.mahout.clustering.classify.WeightedVectorWritable; -import org.apache.mahout.clustering.iterator.ClusterWritable; -import org.apache.mahout.common.Pair; -import org.apache.mahout.common.RandomUtils; -import org.apache.mahout.common.iterator.sequencefile.PathFilters; -import org.apache.mahout.common.iterator.sequencefile.PathType; -import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirValueIterable; -import org.apache.mahout.common.iterator.sequencefile.SequenceFileIterable; -import org.apache.mahout.math.DenseVector; -import org.apache.mahout.math.Vector; -import org.apache.mahout.math.VectorWritable; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -public class DisplayClustering extends Frame { - - private static final Logger log = LoggerFactory.getLogger(DisplayClustering.class); - - protected static final int DS = 72; // default scale = 72 pixels per inch - - protected static final int SIZE = 8; // screen size in inches - - private static final Collection<Vector> SAMPLE_PARAMS = new ArrayList<>(); - - protected static final List<VectorWritable> SAMPLE_DATA = new ArrayList<>(); - - protected static final List<List<Cluster>> CLUSTERS = new ArrayList<>(); - - static final Color[] COLORS = { Color.red, Color.orange, Color.yellow, Color.green, Color.blue, Color.magenta, - Color.lightGray }; - - protected static final double T1 = 3.0; - - protected static final double T2 = 2.8; - - static double significance = 0.05; - - protected static int res; // screen resolution - - public DisplayClustering() { - initialize(); - this.setTitle("Sample Data"); - } - - public void initialize() { - // Get screen resolution - res = Toolkit.getDefaultToolkit().getScreenResolution(); - - // Set Frame size in inches - this.setSize(SIZE * res, SIZE * res); - this.setVisible(true); - this.setTitle("Asymmetric Sample Data"); - - // Window listener to terminate program. - this.addWindowListener(new WindowAdapter() { - @Override - public void windowClosing(WindowEvent e) { - System.exit(0); - } - }); - } - - public static void main(String[] args) throws Exception { - RandomUtils.useTestSeed(); - generateSamples(); - new DisplayClustering(); - } - - // Override the paint() method - @Override - public void paint(Graphics g) { - Graphics2D g2 = (Graphics2D) g; - plotSampleData(g2); - plotSampleParameters(g2); - plotClusters(g2); - } - - protected static void plotClusters(Graphics2D g2) { - int cx = CLUSTERS.size() - 1; - for (List<Cluster> clusters : CLUSTERS) { - g2.setStroke(new BasicStroke(cx == 0 ? 
3 : 1)); - g2.setColor(COLORS[Math.min(COLORS.length - 1, cx--)]); - for (Cluster cluster : clusters) { - plotEllipse(g2, cluster.getCenter(), cluster.getRadius().times(3)); - } - } - } - - protected static void plotSampleParameters(Graphics2D g2) { - Vector v = new DenseVector(2); - Vector dv = new DenseVector(2); - g2.setColor(Color.RED); - for (Vector param : SAMPLE_PARAMS) { - v.set(0, param.get(0)); - v.set(1, param.get(1)); - dv.set(0, param.get(2) * 3); - dv.set(1, param.get(3) * 3); - plotEllipse(g2, v, dv); - } - } - - protected static void plotSampleData(Graphics2D g2) { - double sx = (double) res / DS; - g2.setTransform(AffineTransform.getScaleInstance(sx, sx)); - - // plot the axes - g2.setColor(Color.BLACK); - Vector dv = new DenseVector(2).assign(SIZE / 2.0); - plotRectangle(g2, new DenseVector(2).assign(2), dv); - plotRectangle(g2, new DenseVector(2).assign(-2), dv); - - // plot the sample data - g2.setColor(Color.DARK_GRAY); - dv.assign(0.03); - for (VectorWritable v : SAMPLE_DATA) { - plotRectangle(g2, v.get(), dv); - } - } - - /** - * This method plots points and colors them according to their cluster - * membership, rather than drawing ellipses. - * - * As of commit, this method is used only by K-means spectral clustering. - * Since the cluster assignments are set within the eigenspace of the data, it - * is not inherent that the original data cluster as they would in K-means: - * that is, as symmetric gaussian mixtures. - * - * Since Spectral K-Means uses K-Means to cluster the eigenspace data, the raw - * output is not directly usable. Rather, the cluster assignments from the raw - * output need to be transferred back to the original data. As such, this - * method will read the SequenceFile cluster results of K-means and transfer - * the cluster assignments to the original data, coloring them appropriately. - * - * @param g2 - * @param data - */ - protected static void plotClusteredSampleData(Graphics2D g2, Path data) { - double sx = (double) res / DS; - g2.setTransform(AffineTransform.getScaleInstance(sx, sx)); - - g2.setColor(Color.BLACK); - Vector dv = new DenseVector(2).assign(SIZE / 2.0); - plotRectangle(g2, new DenseVector(2).assign(2), dv); - plotRectangle(g2, new DenseVector(2).assign(-2), dv); - - // plot the sample data, colored according to the cluster they belong to - dv.assign(0.03); - - Path clusteredPointsPath = new Path(data, "clusteredPoints"); - Path inputPath = new Path(clusteredPointsPath, "part-m-00000"); - Map<Integer,Color> colors = new HashMap<>(); - int point = 0; - for (Pair<IntWritable,WeightedVectorWritable> record : new SequenceFileIterable<IntWritable,WeightedVectorWritable>( - inputPath, new Configuration())) { - int clusterId = record.getFirst().get(); - VectorWritable v = SAMPLE_DATA.get(point++); - Integer key = clusterId; - if (!colors.containsKey(key)) { - colors.put(key, COLORS[Math.min(COLORS.length - 1, colors.size())]); - } - plotClusteredRectangle(g2, v.get(), dv, colors.get(key)); - } - } - - /** - * Identical to plotRectangle(), but with the option of setting the color of - * the rectangle's stroke. - * - * NOTE: This should probably be refactored with plotRectangle() since most of - * the code here is direct copy/paste from that method. - * - * @param g2 - * A Graphics2D context. - * @param v - * A vector for the rectangle's center. - * @param dv - * A vector for the rectangle's dimensions. - * @param color - * The color of the rectangle's stroke. 
- */ - protected static void plotClusteredRectangle(Graphics2D g2, Vector v, Vector dv, Color color) { - double[] flip = {1, -1}; - Vector v2 = v.times(new DenseVector(flip)); - v2 = v2.minus(dv.divide(2)); - int h = SIZE / 2; - double x = v2.get(0) + h; - double y = v2.get(1) + h; - - g2.setStroke(new BasicStroke(1)); - g2.setColor(color); - g2.draw(new Rectangle2D.Double(x * DS, y * DS, dv.get(0) * DS, dv.get(1) * DS)); - } - - /** - * Draw a rectangle on the graphics context - * - * @param g2 - * a Graphics2D context - * @param v - * a Vector of rectangle center - * @param dv - * a Vector of rectangle dimensions - */ - protected static void plotRectangle(Graphics2D g2, Vector v, Vector dv) { - double[] flip = {1, -1}; - Vector v2 = v.times(new DenseVector(flip)); - v2 = v2.minus(dv.divide(2)); - int h = SIZE / 2; - double x = v2.get(0) + h; - double y = v2.get(1) + h; - g2.draw(new Rectangle2D.Double(x * DS, y * DS, dv.get(0) * DS, dv.get(1) * DS)); - } - - /** - * Draw an ellipse on the graphics context - * - * @param g2 - * a Graphics2D context - * @param v - * a Vector of ellipse center - * @param dv - * a Vector of ellipse dimensions - */ - protected static void plotEllipse(Graphics2D g2, Vector v, Vector dv) { - double[] flip = {1, -1}; - Vector v2 = v.times(new DenseVector(flip)); - v2 = v2.minus(dv.divide(2)); - int h = SIZE / 2; - double x = v2.get(0) + h; - double y = v2.get(1) + h; - g2.draw(new Ellipse2D.Double(x * DS, y * DS, dv.get(0) * DS, dv.get(1) * DS)); - } - - protected static void generateSamples() { - generateSamples(500, 1, 1, 3); - generateSamples(300, 1, 0, 0.5); - generateSamples(300, 0, 2, 0.1); - } - - protected static void generate2dSamples() { - generate2dSamples(500, 1, 1, 3, 1); - generate2dSamples(300, 1, 0, 0.5, 1); - generate2dSamples(300, 0, 2, 0.1, 0.5); - } - - /** - * Generate random samples and add them to the sampleData - * - * @param num - * int number of samples to generate - * @param mx - * double x-value of the sample mean - * @param my - * double y-value of the sample mean - * @param sd - * double standard deviation of the samples - */ - protected static void generateSamples(int num, double mx, double my, double sd) { - double[] params = {mx, my, sd, sd}; - SAMPLE_PARAMS.add(new DenseVector(params)); - log.info("Generating {} samples m=[{}, {}] sd={}", num, mx, my, sd); - for (int i = 0; i < num; i++) { - SAMPLE_DATA.add(new VectorWritable(new DenseVector(new double[] {UncommonDistributions.rNorm(mx, sd), - UncommonDistributions.rNorm(my, sd)}))); - } - } - - protected static void writeSampleData(Path output) throws IOException { - Configuration conf = new Configuration(); - FileSystem fs = FileSystem.get(output.toUri(), conf); - - try (SequenceFile.Writer writer = new SequenceFile.Writer(fs, conf, output, Text.class, VectorWritable.class)) { - int i = 0; - for (VectorWritable vw : SAMPLE_DATA) { - writer.append(new Text("sample_" + i++), vw); - } - } - } - - protected static List<Cluster> readClustersWritable(Path clustersIn) { - List<Cluster> clusters = new ArrayList<>(); - Configuration conf = new Configuration(); - for (ClusterWritable value : new SequenceFileDirValueIterable<ClusterWritable>(clustersIn, PathType.LIST, - PathFilters.logsCRCFilter(), conf)) { - Cluster cluster = value.getValue(); - log.info( - "Reading Cluster:{} center:{} numPoints:{} radius:{}", - cluster.getId(), AbstractCluster.formatVector(cluster.getCenter(), null), - cluster.getNumObservations(), AbstractCluster.formatVector(cluster.getRadius(), null)); - 
clusters.add(cluster); - } - return clusters; - } - - protected static void loadClustersWritable(Path output) throws IOException { - Configuration conf = new Configuration(); - FileSystem fs = FileSystem.get(output.toUri(), conf); - for (FileStatus s : fs.listStatus(output, new ClustersFilter())) { - List<Cluster> clusters = readClustersWritable(s.getPath()); - CLUSTERS.add(clusters); - } - } - - /** - * Generate random samples and add them to the sampleData - * - * @param num - * int number of samples to generate - * @param mx - * double x-value of the sample mean - * @param my - * double y-value of the sample mean - * @param sdx - * double x-value standard deviation of the samples - * @param sdy - * double y-value standard deviation of the samples - */ - protected static void generate2dSamples(int num, double mx, double my, double sdx, double sdy) { - double[] params = {mx, my, sdx, sdy}; - SAMPLE_PARAMS.add(new DenseVector(params)); - log.info("Generating {} samples m=[{}, {}] sd=[{}, {}]", num, mx, my, sdx, sdy); - for (int i = 0; i < num; i++) { - SAMPLE_DATA.add(new VectorWritable(new DenseVector(new double[] {UncommonDistributions.rNorm(mx, sdx), - UncommonDistributions.rNorm(my, sdy)}))); - } - } - - protected static boolean isSignificant(Cluster cluster) { - return (double) cluster.getNumObservations() / SAMPLE_DATA.size() > significance; - } -} http://git-wip-us.apache.org/repos/asf/mahout/blob/99a5358f/examples/src/main/java/org/apache/mahout/clustering/display/DisplayFuzzyKMeans.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/mahout/clustering/display/DisplayFuzzyKMeans.java b/examples/src/main/java/org/apache/mahout/clustering/display/DisplayFuzzyKMeans.java deleted file mode 100644 index f8ce7c7..0000000 --- a/examples/src/main/java/org/apache/mahout/clustering/display/DisplayFuzzyKMeans.java +++ /dev/null @@ -1,110 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.mahout.clustering.display; - -import java.awt.Graphics; -import java.awt.Graphics2D; -import java.io.IOException; -import java.util.Collection; -import java.util.List; - -import org.apache.hadoop.conf.Configuration; -import org.apache.hadoop.fs.Path; -import org.apache.mahout.clustering.Cluster; -import org.apache.mahout.clustering.classify.ClusterClassifier; -import org.apache.mahout.clustering.fuzzykmeans.FuzzyKMeansDriver; -import org.apache.mahout.clustering.fuzzykmeans.SoftCluster; -import org.apache.mahout.clustering.iterator.ClusterIterator; -import org.apache.mahout.clustering.iterator.FuzzyKMeansClusteringPolicy; -import org.apache.mahout.clustering.kmeans.RandomSeedGenerator; -import org.apache.mahout.common.HadoopUtil; -import org.apache.mahout.common.RandomUtils; -import org.apache.mahout.common.distance.DistanceMeasure; -import org.apache.mahout.common.distance.ManhattanDistanceMeasure; -import org.apache.mahout.math.Vector; - -import com.google.common.collect.Lists; - -public class DisplayFuzzyKMeans extends DisplayClustering { - - DisplayFuzzyKMeans() { - initialize(); - this.setTitle("Fuzzy k-Means Clusters (>" + (int) (significance * 100) + "% of population)"); - } - - // Override the paint() method - @Override - public void paint(Graphics g) { - plotSampleData((Graphics2D) g); - plotClusters((Graphics2D) g); - } - - public static void main(String[] args) throws Exception { - DistanceMeasure measure = new ManhattanDistanceMeasure(); - - Path samples = new Path("samples"); - Path output = new Path("output"); - Configuration conf = new Configuration(); - HadoopUtil.delete(conf, output); - HadoopUtil.delete(conf, samples); - RandomUtils.useTestSeed(); - DisplayClustering.generateSamples(); - writeSampleData(samples); - boolean runClusterer = true; - int maxIterations = 10; - float threshold = 0.001F; - float m = 1.1F; - if (runClusterer) { - runSequentialFuzzyKClusterer(conf, samples, output, measure, maxIterations, m, threshold); - } else { - int numClusters = 3; - runSequentialFuzzyKClassifier(conf, samples, output, measure, numClusters, maxIterations, m, threshold); - } - new DisplayFuzzyKMeans(); - } - - private static void runSequentialFuzzyKClassifier(Configuration conf, Path samples, Path output, - DistanceMeasure measure, int numClusters, int maxIterations, float m, double threshold) throws IOException { - Collection<Vector> points = Lists.newArrayList(); - for (int i = 0; i < numClusters; i++) { - points.add(SAMPLE_DATA.get(i).get()); - } - List<Cluster> initialClusters = Lists.newArrayList(); - int id = 0; - for (Vector point : points) { - initialClusters.add(new SoftCluster(point, id++, measure)); - } - ClusterClassifier prior = new ClusterClassifier(initialClusters, new FuzzyKMeansClusteringPolicy(m, threshold)); - Path priorPath = new Path(output, "classifier-0"); - prior.writeToSeqFiles(priorPath); - - ClusterIterator.iterateSeq(conf, samples, priorPath, output, maxIterations); - loadClustersWritable(output); - } - - private static void runSequentialFuzzyKClusterer(Configuration conf, Path samples, Path output, - DistanceMeasure measure, int maxIterations, float m, double threshold) throws IOException, - ClassNotFoundException, InterruptedException { - Path clustersIn = new Path(output, "random-seeds"); - RandomSeedGenerator.buildRandom(conf, samples, clustersIn, 3, measure); - FuzzyKMeansDriver.run(samples, clustersIn, output, threshold, maxIterations, m, true, true, threshold, - true); - - loadClustersWritable(output); - } -} 
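For reference only — this is not part of the commit; the class name and the assumption that "samples" already holds VectorWritable sequence files (the display class generated them itself) are mine — the headless core of the sequential fuzzy k-means run that DisplayFuzzyKMeans drove can be sketched as follows, reusing the driver calls from the deleted main():

// Illustrative sketch, not from commit 99a5358f: fuzzy k-means without the
// Swing display, mirroring the argument order used in the deleted example.
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.mahout.clustering.fuzzykmeans.FuzzyKMeansDriver;
import org.apache.mahout.clustering.kmeans.RandomSeedGenerator;
import org.apache.mahout.common.HadoopUtil;
import org.apache.mahout.common.distance.ManhattanDistanceMeasure;

public class FuzzyKMeansSketch {
  public static void main(String[] args) throws Exception {
    Configuration conf = new Configuration();
    Path samples = new Path("samples");   // input vectors, assumed already written
    Path output = new Path("output");
    HadoopUtil.delete(conf, output);      // clear any previous run

    // Seed three random initial clusters, then run sequential fuzzy k-means.
    Path clustersIn = new Path(output, "random-seeds");
    ManhattanDistanceMeasure measure = new ManhattanDistanceMeasure();
    RandomSeedGenerator.buildRandom(conf, samples, clustersIn, 3, measure);

    double convergenceDelta = 0.001;
    int maxIterations = 10;
    float m = 1.1f;                       // fuzziness exponent
    FuzzyKMeansDriver.run(samples, clustersIn, output, convergenceDelta,
        maxIterations, m, true, true, convergenceDelta, true);

    // The deleted display code then read results from output/clusters-*
    // (see ClustersFilter above) and output/clusteredPoints.
  }
}

The deleted class only added sample generation and the Swing rendering of the resulting cluster ellipses on top of this pipeline.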
http://git-wip-us.apache.org/repos/asf/mahout/blob/99a5358f/examples/src/main/java/org/apache/mahout/clustering/display/DisplayKMeans.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/mahout/clustering/display/DisplayKMeans.java b/examples/src/main/java/org/apache/mahout/clustering/display/DisplayKMeans.java deleted file mode 100644 index 336d69e..0000000 --- a/examples/src/main/java/org/apache/mahout/clustering/display/DisplayKMeans.java +++ /dev/null @@ -1,106 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.mahout.clustering.display; - -import java.awt.Graphics; -import java.awt.Graphics2D; -import java.io.IOException; -import java.util.Collection; -import java.util.List; - -import org.apache.hadoop.conf.Configuration; -import org.apache.hadoop.fs.Path; -import org.apache.mahout.clustering.Cluster; -import org.apache.mahout.clustering.classify.ClusterClassifier; -import org.apache.mahout.clustering.iterator.ClusterIterator; -import org.apache.mahout.clustering.iterator.KMeansClusteringPolicy; -import org.apache.mahout.clustering.kmeans.KMeansDriver; -import org.apache.mahout.clustering.kmeans.RandomSeedGenerator; -import org.apache.mahout.common.HadoopUtil; -import org.apache.mahout.common.RandomUtils; -import org.apache.mahout.common.distance.DistanceMeasure; -import org.apache.mahout.common.distance.ManhattanDistanceMeasure; -import org.apache.mahout.math.Vector; - -import com.google.common.collect.Lists; - -public class DisplayKMeans extends DisplayClustering { - - DisplayKMeans() { - initialize(); - this.setTitle("k-Means Clusters (>" + (int) (significance * 100) + "% of population)"); - } - - public static void main(String[] args) throws Exception { - DistanceMeasure measure = new ManhattanDistanceMeasure(); - Path samples = new Path("samples"); - Path output = new Path("output"); - Configuration conf = new Configuration(); - HadoopUtil.delete(conf, samples); - HadoopUtil.delete(conf, output); - - RandomUtils.useTestSeed(); - generateSamples(); - writeSampleData(samples); - boolean runClusterer = true; - double convergenceDelta = 0.001; - int numClusters = 3; - int maxIterations = 10; - if (runClusterer) { - runSequentialKMeansClusterer(conf, samples, output, measure, numClusters, maxIterations, convergenceDelta); - } else { - runSequentialKMeansClassifier(conf, samples, output, measure, numClusters, maxIterations, convergenceDelta); - } - new DisplayKMeans(); - } - - private static void runSequentialKMeansClassifier(Configuration conf, Path samples, Path output, - DistanceMeasure measure, int numClusters, int maxIterations, double convergenceDelta) throws IOException { - Collection<Vector> points = Lists.newArrayList(); - for (int 
i = 0; i < numClusters; i++) { - points.add(SAMPLE_DATA.get(i).get()); - } - List<Cluster> initialClusters = Lists.newArrayList(); - int id = 0; - for (Vector point : points) { - initialClusters.add(new org.apache.mahout.clustering.kmeans.Kluster(point, id++, measure)); - } - ClusterClassifier prior = new ClusterClassifier(initialClusters, new KMeansClusteringPolicy(convergenceDelta)); - Path priorPath = new Path(output, Cluster.INITIAL_CLUSTERS_DIR); - prior.writeToSeqFiles(priorPath); - - ClusterIterator.iterateSeq(conf, samples, priorPath, output, maxIterations); - loadClustersWritable(output); - } - - private static void runSequentialKMeansClusterer(Configuration conf, Path samples, Path output, - DistanceMeasure measure, int numClusters, int maxIterations, double convergenceDelta) - throws IOException, InterruptedException, ClassNotFoundException { - Path clustersIn = new Path(output, "random-seeds"); - RandomSeedGenerator.buildRandom(conf, samples, clustersIn, numClusters, measure); - KMeansDriver.run(samples, clustersIn, output, convergenceDelta, maxIterations, true, 0.0, true); - loadClustersWritable(output); - } - - // Override the paint() method - @Override - public void paint(Graphics g) { - plotSampleData((Graphics2D) g); - plotClusters((Graphics2D) g); - } -} http://git-wip-us.apache.org/repos/asf/mahout/blob/99a5358f/examples/src/main/java/org/apache/mahout/clustering/display/DisplaySpectralKMeans.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/mahout/clustering/display/DisplaySpectralKMeans.java b/examples/src/main/java/org/apache/mahout/clustering/display/DisplaySpectralKMeans.java deleted file mode 100644 index 2b70749..0000000 --- a/examples/src/main/java/org/apache/mahout/clustering/display/DisplaySpectralKMeans.java +++ /dev/null @@ -1,85 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.mahout.clustering.display; - -import java.awt.Graphics; -import java.awt.Graphics2D; -import java.io.BufferedWriter; -import java.io.FileWriter; -import java.io.Writer; - -import org.apache.hadoop.conf.Configuration; -import org.apache.hadoop.fs.FileSystem; -import org.apache.hadoop.fs.Path; -import org.apache.mahout.clustering.spectral.kmeans.SpectralKMeansDriver; -import org.apache.mahout.common.HadoopUtil; -import org.apache.mahout.common.RandomUtils; -import org.apache.mahout.common.distance.DistanceMeasure; -import org.apache.mahout.common.distance.ManhattanDistanceMeasure; - -public class DisplaySpectralKMeans extends DisplayClustering { - - protected static final String SAMPLES = "samples"; - protected static final String OUTPUT = "output"; - protected static final String TEMP = "tmp"; - protected static final String AFFINITIES = "affinities"; - - DisplaySpectralKMeans() { - initialize(); - setTitle("Spectral k-Means Clusters (>" + (int) (significance * 100) + "% of population)"); - } - - public static void main(String[] args) throws Exception { - DistanceMeasure measure = new ManhattanDistanceMeasure(); - Path samples = new Path(SAMPLES); - Path output = new Path(OUTPUT); - Path tempDir = new Path(TEMP); - Configuration conf = new Configuration(); - HadoopUtil.delete(conf, samples); - HadoopUtil.delete(conf, output); - - RandomUtils.useTestSeed(); - DisplayClustering.generateSamples(); - writeSampleData(samples); - Path affinities = new Path(output, AFFINITIES); - FileSystem fs = FileSystem.get(output.toUri(), conf); - if (!fs.exists(output)) { - fs.mkdirs(output); - } - - try (Writer writer = new BufferedWriter(new FileWriter(affinities.toString()))){ - for (int i = 0; i < SAMPLE_DATA.size(); i++) { - for (int j = 0; j < SAMPLE_DATA.size(); j++) { - writer.write(i + "," + j + ',' + measure.distance(SAMPLE_DATA.get(i).get(), - SAMPLE_DATA.get(j).get()) + '\n'); - } - } - } - - int maxIter = 10; - double convergenceDelta = 0.001; - SpectralKMeansDriver.run(new Configuration(), affinities, output, SAMPLE_DATA.size(), 3, measure, - convergenceDelta, maxIter, tempDir); - new DisplaySpectralKMeans(); - } - - @Override - public void paint(Graphics g) { - plotClusteredSampleData((Graphics2D) g, new Path(new Path(OUTPUT), "kmeans_out")); - } -} http://git-wip-us.apache.org/repos/asf/mahout/blob/99a5358f/examples/src/main/java/org/apache/mahout/clustering/display/README.txt ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/mahout/clustering/display/README.txt b/examples/src/main/java/org/apache/mahout/clustering/display/README.txt deleted file mode 100644 index 470c16c..0000000 --- a/examples/src/main/java/org/apache/mahout/clustering/display/README.txt +++ /dev/null @@ -1,22 +0,0 @@ -The following classes can be run without parameters to generate a sample data set and -run the reference clustering implementations over them: - -DisplayClustering - generates 1000 samples from three, symmetric distributions. This is the same - data set that is used by the following clustering programs. It displays the points on a screen - and superimposes the model parameters that were used to generate the points. You can edit the - generateSamples() method to change the sample points used by these programs. 
- - * DisplayCanopy - uses Canopy clustering - * DisplayKMeans - uses k-Means clustering - * DisplayFuzzyKMeans - uses Fuzzy k-Means clustering - - * NOTE: some of these programs display the sample points and then superimpose all of the clusters - from each iteration. The last iteration's clusters are in bold red and the previous several are - colored (orange, yellow, green, blue, violet) in order after which all earlier clusters are in - light grey. This helps to visualize how the clusters converge upon a solution over multiple - iterations. - * NOTE: by changing the parameter values (k, ALPHA_0, numIterations) and the display SIGNIFICANCE - you can obtain different results. - - - \ No newline at end of file http://git-wip-us.apache.org/repos/asf/mahout/blob/99a5358f/examples/src/main/java/org/apache/mahout/clustering/streaming/tools/ClusterQualitySummarizer.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/mahout/clustering/streaming/tools/ClusterQualitySummarizer.java b/examples/src/main/java/org/apache/mahout/clustering/streaming/tools/ClusterQualitySummarizer.java deleted file mode 100644 index c29cbc4..0000000 --- a/examples/src/main/java/org/apache/mahout/clustering/streaming/tools/ClusterQualitySummarizer.java +++ /dev/null @@ -1,279 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.mahout.clustering.streaming.tools; - -import java.io.FileOutputStream; -import java.io.IOException; -import java.io.PrintWriter; -import java.util.List; - -import com.google.common.collect.Iterables; -import com.google.common.collect.Lists; -import com.google.common.io.Closeables; -import org.apache.commons.cli2.CommandLine; -import org.apache.commons.cli2.Group; -import org.apache.commons.cli2.Option; -import org.apache.commons.cli2.builder.ArgumentBuilder; -import org.apache.commons.cli2.builder.DefaultOptionBuilder; -import org.apache.commons.cli2.builder.GroupBuilder; -import org.apache.commons.cli2.commandline.Parser; -import org.apache.commons.cli2.util.HelpFormatter; -import org.apache.hadoop.conf.Configuration; -import org.apache.hadoop.fs.Path; -import org.apache.mahout.clustering.iterator.ClusterWritable; -import org.apache.mahout.clustering.ClusteringUtils; -import org.apache.mahout.clustering.streaming.mapreduce.CentroidWritable; -import org.apache.mahout.common.AbstractJob; -import org.apache.mahout.common.distance.DistanceMeasure; -import org.apache.mahout.common.distance.SquaredEuclideanDistanceMeasure; -import org.apache.mahout.common.iterator.sequencefile.PathType; -import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirValueIterable; -import org.apache.mahout.math.Centroid; -import org.apache.mahout.math.Vector; -import org.apache.mahout.math.VectorWritable; -import org.apache.mahout.math.stats.OnlineSummarizer; - -public class ClusterQualitySummarizer extends AbstractJob { - private String outputFile; - - private PrintWriter fileOut; - - private String trainFile; - private String testFile; - private String centroidFile; - private String centroidCompareFile; - private boolean mahoutKMeansFormat; - private boolean mahoutKMeansFormatCompare; - - private DistanceMeasure distanceMeasure = new SquaredEuclideanDistanceMeasure(); - - public void printSummaries(List<OnlineSummarizer> summarizers, String type) { - printSummaries(summarizers, type, fileOut); - } - - public static void printSummaries(List<OnlineSummarizer> summarizers, String type, PrintWriter fileOut) { - double maxDistance = 0; - for (int i = 0; i < summarizers.size(); ++i) { - OnlineSummarizer summarizer = summarizers.get(i); - if (summarizer.getCount() > 1) { - maxDistance = Math.max(maxDistance, summarizer.getMax()); - System.out.printf("Average distance in cluster %d [%d]: %f\n", i, summarizer.getCount(), summarizer.getMean()); - // If there is just one point in the cluster, quartiles cannot be estimated. We'll just assume all the quartiles - // equal the only value. - if (fileOut != null) { - fileOut.printf("%d,%f,%f,%f,%f,%f,%f,%f,%d,%s\n", i, summarizer.getMean(), - summarizer.getSD(), - summarizer.getQuartile(0), - summarizer.getQuartile(1), - summarizer.getQuartile(2), - summarizer.getQuartile(3), - summarizer.getQuartile(4), summarizer.getCount(), type); - } - } else { - System.out.printf("Cluster %d has %d data point. 
Need at least 2 data points in a cluster for" + - " OnlineSummarizer.\n", i, summarizer.getCount()); - } - } - System.out.printf("Num clusters: %d; maxDistance: %f\n", summarizers.size(), maxDistance); - } - - public int run(String[] args) throws IOException { - if (!parseArgs(args)) { - return -1; - } - - Configuration conf = new Configuration(); - try { - fileOut = new PrintWriter(new FileOutputStream(outputFile)); - fileOut.printf("cluster,distance.mean,distance.sd,distance.q0,distance.q1,distance.q2,distance.q3," - + "distance.q4,count,is.train\n"); - - // Reading in the centroids (both pairs, if they exist). - List<Centroid> centroids; - List<Centroid> centroidsCompare = null; - if (mahoutKMeansFormat) { - SequenceFileDirValueIterable<ClusterWritable> clusterIterable = - new SequenceFileDirValueIterable<>(new Path(centroidFile), PathType.GLOB, conf); - centroids = Lists.newArrayList(IOUtils.getCentroidsFromClusterWritableIterable(clusterIterable)); - } else { - SequenceFileDirValueIterable<CentroidWritable> centroidIterable = - new SequenceFileDirValueIterable<>(new Path(centroidFile), PathType.GLOB, conf); - centroids = Lists.newArrayList(IOUtils.getCentroidsFromCentroidWritableIterable(centroidIterable)); - } - - if (centroidCompareFile != null) { - if (mahoutKMeansFormatCompare) { - SequenceFileDirValueIterable<ClusterWritable> clusterCompareIterable = - new SequenceFileDirValueIterable<>(new Path(centroidCompareFile), PathType.GLOB, conf); - centroidsCompare = Lists.newArrayList( - IOUtils.getCentroidsFromClusterWritableIterable(clusterCompareIterable)); - } else { - SequenceFileDirValueIterable<CentroidWritable> centroidCompareIterable = - new SequenceFileDirValueIterable<>(new Path(centroidCompareFile), PathType.GLOB, conf); - centroidsCompare = Lists.newArrayList( - IOUtils.getCentroidsFromCentroidWritableIterable(centroidCompareIterable)); - } - } - - // Reading in the "training" set. - SequenceFileDirValueIterable<VectorWritable> trainIterable = - new SequenceFileDirValueIterable<>(new Path(trainFile), PathType.GLOB, conf); - Iterable<Vector> trainDatapoints = IOUtils.getVectorsFromVectorWritableIterable(trainIterable); - Iterable<Vector> datapoints = trainDatapoints; - - printSummaries(ClusteringUtils.summarizeClusterDistances(trainDatapoints, centroids, - new SquaredEuclideanDistanceMeasure()), "train"); - - // Also adding in the "test" set. - if (testFile != null) { - SequenceFileDirValueIterable<VectorWritable> testIterable = - new SequenceFileDirValueIterable<>(new Path(testFile), PathType.GLOB, conf); - Iterable<Vector> testDatapoints = IOUtils.getVectorsFromVectorWritableIterable(testIterable); - - printSummaries(ClusteringUtils.summarizeClusterDistances(testDatapoints, centroids, - new SquaredEuclideanDistanceMeasure()), "test"); - - datapoints = Iterables.concat(trainDatapoints, testDatapoints); - } - - // At this point, all train/test CSVs have been written. We now compute quality metrics. 
- List<OnlineSummarizer> summaries = - ClusteringUtils.summarizeClusterDistances(datapoints, centroids, distanceMeasure); - List<OnlineSummarizer> compareSummaries = null; - if (centroidsCompare != null) { - compareSummaries = ClusteringUtils.summarizeClusterDistances(datapoints, centroidsCompare, distanceMeasure); - } - System.out.printf("[Dunn Index] First: %f", ClusteringUtils.dunnIndex(centroids, distanceMeasure, summaries)); - if (compareSummaries != null) { - System.out.printf(" Second: %f\n", ClusteringUtils.dunnIndex(centroidsCompare, distanceMeasure, compareSummaries)); - } else { - System.out.printf("\n"); - } - System.out.printf("[Davies-Bouldin Index] First: %f", - ClusteringUtils.daviesBouldinIndex(centroids, distanceMeasure, summaries)); - if (compareSummaries != null) { - System.out.printf(" Second: %f\n", - ClusteringUtils.daviesBouldinIndex(centroidsCompare, distanceMeasure, compareSummaries)); - } else { - System.out.printf("\n"); - } - } catch (IOException e) { - System.out.println(e.getMessage()); - } finally { - Closeables.close(fileOut, false); - } - return 0; - } - - private boolean parseArgs(String[] args) { - DefaultOptionBuilder builder = new DefaultOptionBuilder(); - - Option help = builder.withLongName("help").withDescription("print this list").create(); - - ArgumentBuilder argumentBuilder = new ArgumentBuilder(); - Option inputFileOption = builder.withLongName("input") - .withShortName("i") - .withRequired(true) - .withArgument(argumentBuilder.withName("input").withMaximum(1).create()) - .withDescription("where to get seq files with the vectors (training set)") - .create(); - - Option testInputFileOption = builder.withLongName("testInput") - .withShortName("itest") - .withArgument(argumentBuilder.withName("testInput").withMaximum(1).create()) - .withDescription("where to get seq files with the vectors (test set)") - .create(); - - Option centroidsFileOption = builder.withLongName("centroids") - .withShortName("c") - .withRequired(true) - .withArgument(argumentBuilder.withName("centroids").withMaximum(1).create()) - .withDescription("where to get seq files with the centroids (from Mahout KMeans or StreamingKMeansDriver)") - .create(); - - Option centroidsCompareFileOption = builder.withLongName("centroidsCompare") - .withShortName("cc") - .withRequired(false) - .withArgument(argumentBuilder.withName("centroidsCompare").withMaximum(1).create()) - .withDescription("where to get seq files with the second set of centroids (from Mahout KMeans or " - + "StreamingKMeansDriver)") - .create(); - - Option outputFileOption = builder.withLongName("output") - .withShortName("o") - .withRequired(true) - .withArgument(argumentBuilder.withName("output").withMaximum(1).create()) - .withDescription("where to dump the CSV file with the results") - .create(); - - Option mahoutKMeansFormatOption = builder.withLongName("mahoutkmeansformat") - .withShortName("mkm") - .withDescription("if set, read files as (IntWritable, ClusterWritable) pairs") - .withArgument(argumentBuilder.withName("numpoints").withMaximum(1).create()) - .create(); - - Option mahoutKMeansCompareFormatOption = builder.withLongName("mahoutkmeansformatCompare") - .withShortName("mkmc") - .withDescription("if set, read files as (IntWritable, ClusterWritable) pairs") - .withArgument(argumentBuilder.withName("numpoints").withMaximum(1).create()) - .create(); - - Group normalArgs = new GroupBuilder() - .withOption(help) - .withOption(inputFileOption) - .withOption(testInputFileOption) - .withOption(outputFileOption) - 
.withOption(centroidsFileOption) - .withOption(centroidsCompareFileOption) - .withOption(mahoutKMeansFormatOption) - .withOption(mahoutKMeansCompareFormatOption) - .create(); - - Parser parser = new Parser(); - parser.setHelpOption(help); - parser.setHelpTrigger("--help"); - parser.setGroup(normalArgs); - parser.setHelpFormatter(new HelpFormatter(" ", "", " ", 150)); - - CommandLine cmdLine = parser.parseAndHelp(args); - if (cmdLine == null) { - return false; - } - - trainFile = (String) cmdLine.getValue(inputFileOption); - if (cmdLine.hasOption(testInputFileOption)) { - testFile = (String) cmdLine.getValue(testInputFileOption); - } - centroidFile = (String) cmdLine.getValue(centroidsFileOption); - if (cmdLine.hasOption(centroidsCompareFileOption)) { - centroidCompareFile = (String) cmdLine.getValue(centroidsCompareFileOption); - } - outputFile = (String) cmdLine.getValue(outputFileOption); - if (cmdLine.hasOption(mahoutKMeansFormatOption)) { - mahoutKMeansFormat = true; - } - if (cmdLine.hasOption(mahoutKMeansCompareFormatOption)) { - mahoutKMeansFormatCompare = true; - } - return true; - } - - public static void main(String[] args) throws IOException { - new ClusterQualitySummarizer().run(args); - } -}
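The run() method above reduces each clustering to two scalar scores via ClusteringUtils.dunnIndex() and ClusteringUtils.daviesBouldinIndex(). As a rough, self-contained sketch of what those scores capture (a common textbook formulation with plain Euclidean distance and a made-up class name, not the ClusteringUtils implementation, whose exact values also depend on the configured DistanceMeasure and on how per-cluster spread is estimated from the OnlineSummarizers): higher Dunn and lower Davies-Bouldin both indicate tight, well-separated clusters.

// Hypothetical illustration only; not part of the Mahout code above.
public final class ClusterIndexSketch {

  /** Dunn index: smallest centroid-to-centroid gap over the largest within-cluster spread. */
  static double dunn(double[][] centroids, double[] spreads) {
    double minSeparation = Double.POSITIVE_INFINITY;
    double maxSpread = 0.0;
    for (int i = 0; i < centroids.length; i++) {
      maxSpread = Math.max(maxSpread, spreads[i]);
      for (int j = i + 1; j < centroids.length; j++) {
        minSeparation = Math.min(minSeparation, distance(centroids[i], centroids[j]));
      }
    }
    return minSeparation / maxSpread;
  }

  /** Davies-Bouldin index: average over clusters of the worst (spread_i + spread_j) / separation ratio. */
  static double daviesBouldin(double[][] centroids, double[] spreads) {
    double total = 0.0;
    for (int i = 0; i < centroids.length; i++) {
      double worst = 0.0;
      for (int j = 0; j < centroids.length; j++) {
        if (i != j) {
          worst = Math.max(worst, (spreads[i] + spreads[j]) / distance(centroids[i], centroids[j]));
        }
      }
      total += worst;
    }
    return total / centroids.length;
  }

  private static double distance(double[] a, double[] b) {
    double sum = 0.0;
    for (int i = 0; i < a.length; i++) {
      double d = a[i] - b[i];
      sum += d * d;
    }
    return Math.sqrt(sum);
  }

  public static void main(String[] args) {
    double[][] centroids = {{0, 0}, {4, 0}, {0, 4}};  // three toy cluster centers
    double[] spreads = {0.5, 0.6, 0.4};               // e.g. mean point-to-centroid distance per cluster
    System.out.printf("Dunn = %.3f, Davies-Bouldin = %.3f%n",
        dunn(centroids, spreads), daviesBouldin(centroids, spreads));
  }
}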
