http://git-wip-us.apache.org/repos/asf/mahout/blob/02f75f99/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/AdaptiveLogisticModelParameters.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/AdaptiveLogisticModelParameters.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/AdaptiveLogisticModelParameters.java new file mode 100644 index 0000000..b2ce8b1 --- /dev/null +++ b/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/AdaptiveLogisticModelParameters.java @@ -0,0 +1,236 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.classifier.sgd; + +import org.apache.mahout.math.stats.GlobalOnlineAuc; +import org.apache.mahout.math.stats.GroupedOnlineAuc; +import org.apache.mahout.math.stats.OnlineAuc; + +import java.io.DataInput; +import java.io.DataInputStream; +import java.io.DataOutput; +import java.io.DataOutputStream; +import java.io.File; +import java.io.FileInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Locale; +import java.util.Map; + +public class AdaptiveLogisticModelParameters extends LogisticModelParameters { + + private AdaptiveLogisticRegression alr; + private int interval = 800; + private int averageWindow = 500; + private int threads = 4; + private String prior = "L1"; + private double priorOption = Double.NaN; + private String auc = null; + + public AdaptiveLogisticRegression createAdaptiveLogisticRegression() { + + if (alr == null) { + alr = new AdaptiveLogisticRegression(getMaxTargetCategories(), + getNumFeatures(), createPrior(prior, priorOption)); + alr.setInterval(interval); + alr.setAveragingWindow(averageWindow); + alr.setThreadCount(threads); + alr.setAucEvaluator(createAUC(auc)); + } + return alr; + } + + public void checkParameters() { + if (prior != null) { + String priorUppercase = prior.toUpperCase(Locale.ENGLISH).trim(); + if (("TP".equals(priorUppercase) || "EBP".equals(priorUppercase)) && Double.isNaN(priorOption)) { + throw new IllegalArgumentException("You must specify a double value for TPrior and ElasticBandPrior."); + } + } + } + + private static PriorFunction createPrior(String cmd, double priorOption) { + if (cmd == null) { + return null; + } + if ("L1".equals(cmd.toUpperCase(Locale.ENGLISH).trim())) { + return new L1(); + } + if ("L2".equals(cmd.toUpperCase(Locale.ENGLISH).trim())) { + return new L2(); + } + if ("UP".equals(cmd.toUpperCase(Locale.ENGLISH).trim())) { + return new UniformPrior(); + } + if 
("TP".equals(cmd.toUpperCase(Locale.ENGLISH).trim())) { + return new TPrior(priorOption); + } + if ("EBP".equals(cmd.toUpperCase(Locale.ENGLISH).trim())) { + return new ElasticBandPrior(priorOption); + } + + return null; + } + + private static OnlineAuc createAUC(String cmd) { + if (cmd == null) { + return null; + } + if ("GLOBAL".equals(cmd.toUpperCase(Locale.ENGLISH).trim())) { + return new GlobalOnlineAuc(); + } + if ("GROUPED".equals(cmd.toUpperCase(Locale.ENGLISH).trim())) { + return new GroupedOnlineAuc(); + } + return null; + } + + @Override + public void saveTo(OutputStream out) throws IOException { + if (alr != null) { + alr.close(); + } + setTargetCategories(getCsvRecordFactory().getTargetCategories()); + write(new DataOutputStream(out)); + } + + @Override + public void write(DataOutput out) throws IOException { + out.writeUTF(getTargetVariable()); + out.writeInt(getTypeMap().size()); + for (Map.Entry<String, String> entry : getTypeMap().entrySet()) { + out.writeUTF(entry.getKey()); + out.writeUTF(entry.getValue()); + } + out.writeInt(getNumFeatures()); + out.writeInt(getMaxTargetCategories()); + out.writeInt(getTargetCategories().size()); + for (String category : getTargetCategories()) { + out.writeUTF(category); + } + + out.writeInt(interval); + out.writeInt(averageWindow); + out.writeInt(threads); + out.writeUTF(prior); + out.writeDouble(priorOption); + out.writeUTF(auc); + + // skip csv + alr.write(out); + } + + @Override + public void readFields(DataInput in) throws IOException { + setTargetVariable(in.readUTF()); + int typeMapSize = in.readInt(); + Map<String, String> typeMap = new HashMap<>(typeMapSize); + for (int i = 0; i < typeMapSize; i++) { + String key = in.readUTF(); + String value = in.readUTF(); + typeMap.put(key, value); + } + setTypeMap(typeMap); + + setNumFeatures(in.readInt()); + setMaxTargetCategories(in.readInt()); + int targetCategoriesSize = in.readInt(); + List<String> targetCategories = new ArrayList<>(targetCategoriesSize); + for (int i = 0; i < targetCategoriesSize; i++) { + targetCategories.add(in.readUTF()); + } + setTargetCategories(targetCategories); + + interval = in.readInt(); + averageWindow = in.readInt(); + threads = in.readInt(); + prior = in.readUTF(); + priorOption = in.readDouble(); + auc = in.readUTF(); + + alr = new AdaptiveLogisticRegression(); + alr.readFields(in); + } + + + private static AdaptiveLogisticModelParameters loadFromStream(InputStream in) throws IOException { + AdaptiveLogisticModelParameters result = new AdaptiveLogisticModelParameters(); + result.readFields(new DataInputStream(in)); + return result; + } + + public static AdaptiveLogisticModelParameters loadFromFile(File in) throws IOException { + try (InputStream input = new FileInputStream(in)) { + return loadFromStream(input); + } + } + + public int getInterval() { + return interval; + } + + public void setInterval(int interval) { + this.interval = interval; + } + + public int getAverageWindow() { + return averageWindow; + } + + public void setAverageWindow(int averageWindow) { + this.averageWindow = averageWindow; + } + + public int getThreads() { + return threads; + } + + public void setThreads(int threads) { + this.threads = threads; + } + + public String getPrior() { + return prior; + } + + public void setPrior(String prior) { + this.prior = prior; + } + + public String getAuc() { + return auc; + } + + public void setAuc(String auc) { + this.auc = auc; + } + + public double getPriorOption() { + return priorOption; + } + + public void setPriorOption(double 
priorOption) { + this.priorOption = priorOption; + } + + +}
http://git-wip-us.apache.org/repos/asf/mahout/blob/02f75f99/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/LogisticModelParameters.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/LogisticModelParameters.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/LogisticModelParameters.java new file mode 100644 index 0000000..e762924 --- /dev/null +++ b/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/LogisticModelParameters.java @@ -0,0 +1,265 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.classifier.sgd; + +import com.google.common.base.Preconditions; +import com.google.common.io.Closeables; +import java.io.DataInput; +import java.io.DataInputStream; +import java.io.DataOutput; +import java.io.DataOutputStream; +import java.io.File; +import java.io.FileInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import org.apache.hadoop.io.Writable; + +/** + * Encapsulates everything we need to know about a model and how it reads and vectorizes its input. + * This encapsulation allows us to coherently save and restore a model from a file. This also + * allows us to keep command line arguments that affect learning in a coherent way. + */ +public class LogisticModelParameters implements Writable { + private String targetVariable; + private Map<String, String> typeMap; + private int numFeatures; + private boolean useBias; + private int maxTargetCategories; + private List<String> targetCategories; + private double lambda; + private double learningRate; + private CsvRecordFactory csv; + private OnlineLogisticRegression lr; + + /** + * Returns a CsvRecordFactory compatible with this logistic model. The reason that this is tied + * in here is so that we have access to the list of target categories when it comes time to save + * the model. If the input isn't CSV, then calling setTargetCategories before calling saveTo will + * suffice. + * + * @return The CsvRecordFactory. + */ + public CsvRecordFactory getCsvRecordFactory() { + if (csv == null) { + csv = new CsvRecordFactory(getTargetVariable(), getTypeMap()) + .maxTargetValue(getMaxTargetCategories()) + .includeBiasTerm(useBias()); + if (targetCategories != null) { + csv.defineTargetCategories(targetCategories); + } + } + return csv; + } + + /** + * Creates a logistic regression trainer using the parameters collected here. 
+ * + * @return The newly allocated OnlineLogisticRegression object + */ + public OnlineLogisticRegression createRegression() { + if (lr == null) { + lr = new OnlineLogisticRegression(getMaxTargetCategories(), getNumFeatures(), new L1()) + .lambda(getLambda()) + .learningRate(getLearningRate()) + .alpha(1 - 1.0e-3); + } + return lr; + } + + /** + * Saves a model to an output stream. + */ + public void saveTo(OutputStream out) throws IOException { + Closeables.close(lr, false); + targetCategories = getCsvRecordFactory().getTargetCategories(); + write(new DataOutputStream(out)); + } + + /** + * Reads a model from a stream. + */ + public static LogisticModelParameters loadFrom(InputStream in) throws IOException { + LogisticModelParameters result = new LogisticModelParameters(); + result.readFields(new DataInputStream(in)); + return result; + } + + /** + * Reads a model from a file. + * @throws IOException If there is an error opening or closing the file. + */ + public static LogisticModelParameters loadFrom(File in) throws IOException { + try (InputStream input = new FileInputStream(in)) { + return loadFrom(input); + } + } + + + @Override + public void write(DataOutput out) throws IOException { + out.writeUTF(targetVariable); + out.writeInt(typeMap.size()); + for (Map.Entry<String,String> entry : typeMap.entrySet()) { + out.writeUTF(entry.getKey()); + out.writeUTF(entry.getValue()); + } + out.writeInt(numFeatures); + out.writeBoolean(useBias); + out.writeInt(maxTargetCategories); + + if (targetCategories == null) { + out.writeInt(0); + } else { + out.writeInt(targetCategories.size()); + for (String category : targetCategories) { + out.writeUTF(category); + } + } + out.writeDouble(lambda); + out.writeDouble(learningRate); + // skip csv + lr.write(out); + } + + @Override + public void readFields(DataInput in) throws IOException { + targetVariable = in.readUTF(); + int typeMapSize = in.readInt(); + typeMap = new HashMap<>(typeMapSize); + for (int i = 0; i < typeMapSize; i++) { + String key = in.readUTF(); + String value = in.readUTF(); + typeMap.put(key, value); + } + numFeatures = in.readInt(); + useBias = in.readBoolean(); + maxTargetCategories = in.readInt(); + int targetCategoriesSize = in.readInt(); + targetCategories = new ArrayList<>(targetCategoriesSize); + for (int i = 0; i < targetCategoriesSize; i++) { + targetCategories.add(in.readUTF()); + } + lambda = in.readDouble(); + learningRate = in.readDouble(); + csv = null; + lr = new OnlineLogisticRegression(); + lr.readFields(in); + } + + /** + * Sets the types of the predictors. This will later be used when reading CSV data. If you don't + * use the CSV data and convert to vectors on your own, you don't need to call this. + * + * @param predictorList The list of variable names. + * @param typeList The list of types in the format preferred by CsvRecordFactory. + */ + public void setTypeMap(Iterable<String> predictorList, List<String> typeList) { + Preconditions.checkArgument(!typeList.isEmpty(), "Must have at least one type specifier"); + typeMap = new HashMap<>(); + Iterator<String> iTypes = typeList.iterator(); + String lastType = null; + for (Object x : predictorList) { + // type list can be short .. we just repeat last spec + if (iTypes.hasNext()) { + lastType = iTypes.next(); + } + typeMap.put(x.toString(), lastType); + } + } + + /** + * Sets the target variable. If you don't use the CSV record factory, then this is irrelevant. + * + * @param targetVariable The name of the target variable. 
+ */ + public void setTargetVariable(String targetVariable) { + this.targetVariable = targetVariable; + } + + /** + * Sets the number of target categories to be considered. + * + * @param maxTargetCategories The number of target categories. + */ + public void setMaxTargetCategories(int maxTargetCategories) { + this.maxTargetCategories = maxTargetCategories; + } + + public void setNumFeatures(int numFeatures) { + this.numFeatures = numFeatures; + } + + public void setTargetCategories(List<String> targetCategories) { + this.targetCategories = targetCategories; + maxTargetCategories = targetCategories.size(); + } + + public List<String> getTargetCategories() { + return this.targetCategories; + } + + public void setUseBias(boolean useBias) { + this.useBias = useBias; + } + + public boolean useBias() { + return useBias; + } + + public String getTargetVariable() { + return targetVariable; + } + + public Map<String, String> getTypeMap() { + return typeMap; + } + + public void setTypeMap(Map<String, String> map) { + this.typeMap = map; + } + + public int getNumFeatures() { + return numFeatures; + } + + public int getMaxTargetCategories() { + return maxTargetCategories; + } + + public double getLambda() { + return lambda; + } + + public void setLambda(double lambda) { + this.lambda = lambda; + } + + public double getLearningRate() { + return learningRate; + } + + public void setLearningRate(double learningRate) { + this.learningRate = learningRate; + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/02f75f99/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/PrintResourceOrFile.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/PrintResourceOrFile.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/PrintResourceOrFile.java new file mode 100644 index 0000000..3ec6a06 --- /dev/null +++ b/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/PrintResourceOrFile.java @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.classifier.sgd; + +import com.google.common.base.Preconditions; + +import java.io.BufferedReader; + +/** + * Uses the same logic as TrainLogistic and RunLogistic for finding an input, but instead + * of processing the input, this class just prints the input to standard out. 
+ */ +public final class PrintResourceOrFile { + + private PrintResourceOrFile() { + } + + public static void main(String[] args) throws Exception { + Preconditions.checkArgument(args.length == 1, "Must have a single argument that names a file or resource."); + try (BufferedReader in = TrainLogistic.open(args[0])){ + String line; + while ((line = in.readLine()) != null) { + System.out.println(line); + } + } + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/02f75f99/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/RunAdaptiveLogistic.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/RunAdaptiveLogistic.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/RunAdaptiveLogistic.java new file mode 100644 index 0000000..678a8f5 --- /dev/null +++ b/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/RunAdaptiveLogistic.java @@ -0,0 +1,197 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.classifier.sgd; + +import org.apache.commons.cli2.CommandLine; +import org.apache.commons.cli2.Group; +import org.apache.commons.cli2.Option; +import org.apache.commons.cli2.builder.ArgumentBuilder; +import org.apache.commons.cli2.builder.DefaultOptionBuilder; +import org.apache.commons.cli2.builder.GroupBuilder; +import org.apache.commons.cli2.commandline.Parser; +import org.apache.commons.cli2.util.HelpFormatter; +import org.apache.commons.io.Charsets; +import org.apache.mahout.classifier.sgd.AdaptiveLogisticRegression.Wrapper; +import org.apache.mahout.ep.State; +import org.apache.mahout.math.SequentialAccessSparseVector; +import org.apache.mahout.math.Vector; + +import java.io.BufferedReader; +import java.io.BufferedWriter; +import java.io.File; +import java.io.FileOutputStream; +import java.io.OutputStreamWriter; +import java.io.PrintWriter; +import java.util.HashMap; +import java.util.Map; + +public final class RunAdaptiveLogistic { + + private static String inputFile; + private static String modelFile; + private static String outputFile; + private static String idColumn; + private static boolean maxScoreOnly; + + private RunAdaptiveLogistic() { + } + + public static void main(String[] args) throws Exception { + mainToOutput(args, new PrintWriter(new OutputStreamWriter(System.out, Charsets.UTF_8), true)); + } + + static void mainToOutput(String[] args, PrintWriter output) throws Exception { + if (!parseArgs(args)) { + return; + } + AdaptiveLogisticModelParameters lmp = AdaptiveLogisticModelParameters + .loadFromFile(new File(modelFile)); + + CsvRecordFactory csv = lmp.getCsvRecordFactory(); + csv.setIdName(idColumn); + + AdaptiveLogisticRegression lr = lmp.createAdaptiveLogisticRegression(); + + State<Wrapper, CrossFoldLearner> best = lr.getBest(); + if (best == null) { + output.println("AdaptiveLogisticRegression has not be trained probably."); + return; + } + CrossFoldLearner learner = best.getPayload().getLearner(); + + BufferedReader in = TrainAdaptiveLogistic.open(inputFile); + int k = 0; + + try (BufferedWriter out = new BufferedWriter(new OutputStreamWriter(new FileOutputStream(outputFile), + Charsets.UTF_8))) { + out.write(idColumn + ",target,score"); + out.newLine(); + + String line = in.readLine(); + csv.firstLine(line); + line = in.readLine(); + Map<String, Double> results = new HashMap<>(); + while (line != null) { + Vector v = new SequentialAccessSparseVector(lmp.getNumFeatures()); + csv.processLine(line, v, false); + Vector scores = learner.classifyFull(v); + results.clear(); + if (maxScoreOnly) { + results.put(csv.getTargetLabel(scores.maxValueIndex()), + scores.maxValue()); + } else { + for (int i = 0; i < scores.size(); i++) { + results.put(csv.getTargetLabel(i), scores.get(i)); + } + } + + for (Map.Entry<String, Double> entry : results.entrySet()) { + out.write(csv.getIdString(line) + ',' + entry.getKey() + ',' + entry.getValue()); + out.newLine(); + } + k++; + if (k % 100 == 0) { + output.println(k + " records processed"); + } + line = in.readLine(); + } + out.flush(); + } + output.println(k + " records processed totally."); + } + + private static boolean parseArgs(String[] args) { + DefaultOptionBuilder builder = new DefaultOptionBuilder(); + + Option help = builder.withLongName("help") + .withDescription("print this list").create(); + + Option quiet = builder.withLongName("quiet") + .withDescription("be extra quiet").create(); + + ArgumentBuilder argumentBuilder = new ArgumentBuilder(); + Option inputFileOption = builder + 
.withLongName("input") + .withRequired(true) + .withArgument( + argumentBuilder.withName("input").withMaximum(1) + .create()) + .withDescription("where to get training data").create(); + + Option modelFileOption = builder + .withLongName("model") + .withRequired(true) + .withArgument( + argumentBuilder.withName("model").withMaximum(1) + .create()) + .withDescription("where to get the trained model").create(); + + Option outputFileOption = builder + .withLongName("output") + .withRequired(true) + .withDescription("the file path to output scores") + .withArgument(argumentBuilder.withName("output").withMaximum(1).create()) + .create(); + + Option idColumnOption = builder + .withLongName("idcolumn") + .withRequired(true) + .withDescription("the name of the id column for each record") + .withArgument(argumentBuilder.withName("idcolumn").withMaximum(1).create()) + .create(); + + Option maxScoreOnlyOption = builder + .withLongName("maxscoreonly") + .withDescription("only output the target label with max scores") + .create(); + + Group normalArgs = new GroupBuilder() + .withOption(help).withOption(quiet) + .withOption(inputFileOption).withOption(modelFileOption) + .withOption(outputFileOption).withOption(idColumnOption) + .withOption(maxScoreOnlyOption) + .create(); + + Parser parser = new Parser(); + parser.setHelpOption(help); + parser.setHelpTrigger("--help"); + parser.setGroup(normalArgs); + parser.setHelpFormatter(new HelpFormatter(" ", "", " ", 130)); + CommandLine cmdLine = parser.parseAndHelp(args); + + if (cmdLine == null) { + return false; + } + + inputFile = getStringArgument(cmdLine, inputFileOption); + modelFile = getStringArgument(cmdLine, modelFileOption); + outputFile = getStringArgument(cmdLine, outputFileOption); + idColumn = getStringArgument(cmdLine, idColumnOption); + maxScoreOnly = getBooleanArgument(cmdLine, maxScoreOnlyOption); + return true; + } + + private static boolean getBooleanArgument(CommandLine cmdLine, Option option) { + return cmdLine.hasOption(option); + } + + private static String getStringArgument(CommandLine cmdLine, Option inputFile) { + return (String) cmdLine.getValue(inputFile); + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/02f75f99/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/RunLogistic.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/RunLogistic.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/RunLogistic.java new file mode 100644 index 0000000..2d57016 --- /dev/null +++ b/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/RunLogistic.java @@ -0,0 +1,163 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.classifier.sgd; + +import org.apache.commons.cli2.CommandLine; +import org.apache.commons.cli2.Group; +import org.apache.commons.cli2.Option; +import org.apache.commons.cli2.builder.ArgumentBuilder; +import org.apache.commons.cli2.builder.DefaultOptionBuilder; +import org.apache.commons.cli2.builder.GroupBuilder; +import org.apache.commons.cli2.commandline.Parser; +import org.apache.commons.cli2.util.HelpFormatter; +import org.apache.commons.io.Charsets; +import org.apache.mahout.classifier.evaluation.Auc; +import org.apache.mahout.math.Matrix; +import org.apache.mahout.math.SequentialAccessSparseVector; +import org.apache.mahout.math.Vector; + +import java.io.BufferedReader; +import java.io.File; +import java.io.OutputStreamWriter; +import java.io.PrintWriter; +import java.util.Locale; + +public final class RunLogistic { + + private static String inputFile; + private static String modelFile; + private static boolean showAuc; + private static boolean showScores; + private static boolean showConfusion; + + private RunLogistic() { + } + + public static void main(String[] args) throws Exception { + mainToOutput(args, new PrintWriter(new OutputStreamWriter(System.out, Charsets.UTF_8), true)); + } + + static void mainToOutput(String[] args, PrintWriter output) throws Exception { + if (parseArgs(args)) { + if (!showAuc && !showConfusion && !showScores) { + showAuc = true; + showConfusion = true; + } + + Auc collector = new Auc(); + LogisticModelParameters lmp = LogisticModelParameters.loadFrom(new File(modelFile)); + + CsvRecordFactory csv = lmp.getCsvRecordFactory(); + OnlineLogisticRegression lr = lmp.createRegression(); + BufferedReader in = TrainLogistic.open(inputFile); + String line = in.readLine(); + csv.firstLine(line); + line = in.readLine(); + if (showScores) { + output.println("\"target\",\"model-output\",\"log-likelihood\""); + } + while (line != null) { + Vector v = new SequentialAccessSparseVector(lmp.getNumFeatures()); + int target = csv.processLine(line, v); + + double score = lr.classifyScalar(v); + if (showScores) { + output.printf(Locale.ENGLISH, "%d,%.3f,%.6f%n", target, score, lr.logLikelihood(target, v)); + } + collector.add(target, score); + line = in.readLine(); + } + + if (showAuc) { + output.printf(Locale.ENGLISH, "AUC = %.2f%n", collector.auc()); + } + if (showConfusion) { + Matrix m = collector.confusion(); + output.printf(Locale.ENGLISH, "confusion: [[%.1f, %.1f], [%.1f, %.1f]]%n", + m.get(0, 0), m.get(1, 0), m.get(0, 1), m.get(1, 1)); + m = collector.entropy(); + output.printf(Locale.ENGLISH, "entropy: [[%.1f, %.1f], [%.1f, %.1f]]%n", + m.get(0, 0), m.get(1, 0), m.get(0, 1), m.get(1, 1)); + } + } + } + + private static boolean parseArgs(String[] args) { + DefaultOptionBuilder builder = new DefaultOptionBuilder(); + + Option help = builder.withLongName("help").withDescription("print this list").create(); + + Option quiet = builder.withLongName("quiet").withDescription("be extra quiet").create(); + + Option auc = builder.withLongName("auc").withDescription("print AUC").create(); + Option confusion = builder.withLongName("confusion").withDescription("print confusion matrix").create(); + + Option scores = builder.withLongName("scores").withDescription("print scores").create(); + + ArgumentBuilder argumentBuilder = new ArgumentBuilder(); + Option inputFileOption = builder.withLongName("input") + .withRequired(true) + 
.withArgument(argumentBuilder.withName("input").withMaximum(1).create()) + .withDescription("where to get training data") + .create(); + + Option modelFileOption = builder.withLongName("model") + .withRequired(true) + .withArgument(argumentBuilder.withName("model").withMaximum(1).create()) + .withDescription("where to get a model") + .create(); + + Group normalArgs = new GroupBuilder() + .withOption(help) + .withOption(quiet) + .withOption(auc) + .withOption(scores) + .withOption(confusion) + .withOption(inputFileOption) + .withOption(modelFileOption) + .create(); + + Parser parser = new Parser(); + parser.setHelpOption(help); + parser.setHelpTrigger("--help"); + parser.setGroup(normalArgs); + parser.setHelpFormatter(new HelpFormatter(" ", "", " ", 130)); + CommandLine cmdLine = parser.parseAndHelp(args); + + if (cmdLine == null) { + return false; + } + + inputFile = getStringArgument(cmdLine, inputFileOption); + modelFile = getStringArgument(cmdLine, modelFileOption); + showAuc = getBooleanArgument(cmdLine, auc); + showScores = getBooleanArgument(cmdLine, scores); + showConfusion = getBooleanArgument(cmdLine, confusion); + + return true; + } + + private static boolean getBooleanArgument(CommandLine cmdLine, Option option) { + return cmdLine.hasOption(option); + } + + private static String getStringArgument(CommandLine cmdLine, Option inputFile) { + return (String) cmdLine.getValue(inputFile); + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/02f75f99/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/SGDHelper.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/SGDHelper.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/SGDHelper.java new file mode 100644 index 0000000..c657803 --- /dev/null +++ b/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/SGDHelper.java @@ -0,0 +1,151 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.classifier.sgd; + +import com.google.common.collect.Multiset; +import org.apache.mahout.classifier.NewsgroupHelper; +import org.apache.mahout.ep.State; +import org.apache.mahout.math.Matrix; +import org.apache.mahout.math.Vector; +import org.apache.mahout.math.function.DoubleFunction; +import org.apache.mahout.math.function.Functions; +import org.apache.mahout.vectorizer.encoders.Dictionary; + +import java.io.File; +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.Random; +import java.util.Set; +import java.util.TreeMap; + +public final class SGDHelper { + + private static final String[] LEAK_LABELS = {"none", "month-year", "day-month-year"}; + + private SGDHelper() { + } + + public static void dissect(int leakType, + Dictionary dictionary, + AdaptiveLogisticRegression learningAlgorithm, + Iterable<File> files, Multiset<String> overallCounts) throws IOException { + CrossFoldLearner model = learningAlgorithm.getBest().getPayload().getLearner(); + model.close(); + + Map<String, Set<Integer>> traceDictionary = new TreeMap<>(); + ModelDissector md = new ModelDissector(); + + NewsgroupHelper helper = new NewsgroupHelper(); + helper.getEncoder().setTraceDictionary(traceDictionary); + helper.getBias().setTraceDictionary(traceDictionary); + + for (File file : permute(files, helper.getRandom()).subList(0, 500)) { + String ng = file.getParentFile().getName(); + int actual = dictionary.intern(ng); + + traceDictionary.clear(); + Vector v = helper.encodeFeatureVector(file, actual, leakType, overallCounts); + md.update(v, traceDictionary, model); + } + + List<String> ngNames = new ArrayList<>(dictionary.values()); + List<ModelDissector.Weight> weights = md.summary(100); + System.out.println("============"); + System.out.println("Model Dissection"); + for (ModelDissector.Weight w : weights) { + System.out.printf("%s\t%.1f\t%s\t%.1f\t%s\t%.1f\t%s%n", + w.getFeature(), w.getWeight(), ngNames.get(w.getMaxImpact() + 1), + w.getCategory(1), w.getWeight(1), w.getCategory(2), w.getWeight(2)); + } + } + + public static List<File> permute(Iterable<File> files, Random rand) { + List<File> r = new ArrayList<>(); + for (File file : files) { + int i = rand.nextInt(r.size() + 1); + if (i == r.size()) { + r.add(file); + } else { + r.add(r.get(i)); + r.set(i, file); + } + } + return r; + } + + static void analyzeState(SGDInfo info, int leakType, int k, State<AdaptiveLogisticRegression.Wrapper, + CrossFoldLearner> best) throws IOException { + int bump = info.getBumps()[(int) Math.floor(info.getStep()) % info.getBumps().length]; + int scale = (int) Math.pow(10, Math.floor(info.getStep() / info.getBumps().length)); + double maxBeta; + double nonZeros; + double positive; + double norm; + + double lambda = 0; + double mu = 0; + + if (best != null) { + CrossFoldLearner state = best.getPayload().getLearner(); + info.setAverageCorrect(state.percentCorrect()); + info.setAverageLL(state.logLikelihood()); + + OnlineLogisticRegression model = state.getModels().get(0); + // finish off pending regularization + model.close(); + + Matrix beta = model.getBeta(); + maxBeta = beta.aggregate(Functions.MAX, Functions.ABS); + nonZeros = beta.aggregate(Functions.PLUS, new DoubleFunction() { + @Override + public double apply(double v) { + return Math.abs(v) > 1.0e-6 ? 1 : 0; + } + }); + positive = beta.aggregate(Functions.PLUS, new DoubleFunction() { + @Override + public double apply(double v) { + return v > 0 ? 
1 : 0; + } + }); + norm = beta.aggregate(Functions.PLUS, Functions.ABS); + + lambda = best.getMappedParams()[0]; + mu = best.getMappedParams()[1]; + } else { + maxBeta = 0; + nonZeros = 0; + positive = 0; + norm = 0; + } + if (k % (bump * scale) == 0) { + if (best != null) { + File modelFile = new File(System.getProperty("java.io.tmpdir"), "news-group-" + k + ".model"); + ModelSerializer.writeBinary(modelFile.getAbsolutePath(), best.getPayload().getLearner().getModels().get(0)); + } + + info.setStep(info.getStep() + 0.25); + System.out.printf("%.2f\t%.2f\t%.2f\t%.2f\t%.8g\t%.8g\t", maxBeta, nonZeros, positive, norm, lambda, mu); + System.out.printf("%d\t%.3f\t%.2f\t%s%n", + k, info.getAverageLL(), info.getAverageCorrect() * 100, LEAK_LABELS[leakType % 3]); + } + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/02f75f99/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/SGDInfo.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/SGDInfo.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/SGDInfo.java new file mode 100644 index 0000000..be55d43 --- /dev/null +++ b/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/SGDInfo.java @@ -0,0 +1,59 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.classifier.sgd; + +final class SGDInfo { + + private double averageLL; + private double averageCorrect; + private double step; + private int[] bumps = {1, 2, 5}; + + double getAverageLL() { + return averageLL; + } + + void setAverageLL(double averageLL) { + this.averageLL = averageLL; + } + + double getAverageCorrect() { + return averageCorrect; + } + + void setAverageCorrect(double averageCorrect) { + this.averageCorrect = averageCorrect; + } + + double getStep() { + return step; + } + + void setStep(double step) { + this.step = step; + } + + int[] getBumps() { + return bumps; + } + + void setBumps(int[] bumps) { + this.bumps = bumps; + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/02f75f99/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/SimpleCsvExamples.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/SimpleCsvExamples.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/SimpleCsvExamples.java new file mode 100644 index 0000000..b3da452 --- /dev/null +++ b/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/SimpleCsvExamples.java @@ -0,0 +1,283 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.classifier.sgd; + +import com.google.common.base.Joiner; +import com.google.common.base.Splitter; +import com.google.common.collect.Lists; +import com.google.common.io.Closeables; +import com.google.common.io.Files; +import org.apache.commons.io.Charsets; +import org.apache.mahout.common.RandomUtils; +import org.apache.mahout.math.DenseVector; +import org.apache.mahout.math.Vector; +import org.apache.mahout.math.list.IntArrayList; +import org.apache.mahout.math.stats.OnlineSummarizer; +import org.apache.mahout.vectorizer.encoders.ConstantValueEncoder; +import org.apache.mahout.vectorizer.encoders.FeatureVectorEncoder; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.BufferedReader; +import java.io.Closeable; +import java.io.File; +import java.io.FileInputStream; +import java.io.FileOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStreamWriter; +import java.io.PrintWriter; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.List; +import java.util.Random; + +/** + * Shows how different encoding choices can make big speed differences. + * <p/> + * Run with command line options --generate 1000000 test.csv to generate a million data lines in + * test.csv. 
+ * <p/> + * Run with command line options --parser test.csv to time how long it takes to parse and encode + * those million data points + * <p/> + * Run with command line options --fast test.csv to time how long it takes to parse and encode those + * million data points using byte-level parsing and direct value encoding. + * <p/> + * This doesn't demonstrate text encoding which is subject to somewhat different tricks. The basic + * idea of caching hash locations and byte level parsing still very much applies to text, however. + */ +public final class SimpleCsvExamples { + + public static final char SEPARATOR_CHAR = '\t'; + private static final int FIELDS = 100; + + private static final Logger log = LoggerFactory.getLogger(SimpleCsvExamples.class); + + private SimpleCsvExamples() {} + + public static void main(String[] args) throws IOException { + FeatureVectorEncoder[] encoder = new FeatureVectorEncoder[FIELDS]; + for (int i = 0; i < FIELDS; i++) { + encoder[i] = new ConstantValueEncoder("v" + 1); + } + + OnlineSummarizer[] s = new OnlineSummarizer[FIELDS]; + for (int i = 0; i < FIELDS; i++) { + s[i] = new OnlineSummarizer(); + } + long t0 = System.currentTimeMillis(); + Vector v = new DenseVector(1000); + if ("--generate".equals(args[0])) { + try (PrintWriter out = + new PrintWriter(new OutputStreamWriter(new FileOutputStream(new File(args[2])), Charsets.UTF_8))) { + int n = Integer.parseInt(args[1]); + for (int i = 0; i < n; i++) { + Line x = Line.generate(); + out.println(x); + } + } + } else if ("--parse".equals(args[0])) { + try (BufferedReader in = Files.newReader(new File(args[1]), Charsets.UTF_8)){ + String line = in.readLine(); + while (line != null) { + v.assign(0); + Line x = new Line(line); + for (int i = 0; i < FIELDS; i++) { + s[i].add(x.getDouble(i)); + encoder[i].addToVector(x.get(i), v); + } + line = in.readLine(); + } + } + String separator = ""; + for (int i = 0; i < FIELDS; i++) { + System.out.printf("%s%.3f", separator, s[i].getMean()); + separator = ","; + } + } else if ("--fast".equals(args[0])) { + try (FastLineReader in = new FastLineReader(new FileInputStream(args[1]))){ + FastLine line = in.read(); + while (line != null) { + v.assign(0); + for (int i = 0; i < FIELDS; i++) { + double z = line.getDouble(i); + s[i].add(z); + encoder[i].addToVector((byte[]) null, z, v); + } + line = in.read(); + } + } + + String separator = ""; + for (int i = 0; i < FIELDS; i++) { + System.out.printf("%s%.3f", separator, s[i].getMean()); + separator = ","; + } + } + System.out.printf("\nElapsed time = %.3f%n", (System.currentTimeMillis() - t0) / 1000.0); + } + + + private static final class Line { + private static final Splitter ON_TABS = Splitter.on(SEPARATOR_CHAR).trimResults(); + public static final Joiner WITH_COMMAS = Joiner.on(SEPARATOR_CHAR); + + public static final Random RAND = RandomUtils.getRandom(); + + private final List<String> data; + + private Line(CharSequence line) { + data = Lists.newArrayList(ON_TABS.split(line)); + } + + private Line() { + data = new ArrayList<>(); + } + + public double getDouble(int field) { + return Double.parseDouble(data.get(field)); + } + + /** + * Generate a random line with 20 fields each with integer values. + * + * @return A new line with data. 
+ */ + public static Line generate() { + Line r = new Line(); + for (int i = 0; i < FIELDS; i++) { + double mean = ((i + 1) * 257) % 50 + 1; + r.data.add(Integer.toString(randomValue(mean))); + } + return r; + } + + /** + * Returns a random exponentially distributed integer with a particular mean value. This is + * just a way to create more small numbers than big numbers. + * + * @param mean mean of the distribution + * @return random exponentially distributed integer with the specific mean + */ + private static int randomValue(double mean) { + return (int) (-mean * Math.log1p(-RAND.nextDouble())); + } + + @Override + public String toString() { + return WITH_COMMAS.join(data); + } + + public String get(int field) { + return data.get(field); + } + } + + private static final class FastLine { + + private final ByteBuffer base; + private final IntArrayList start = new IntArrayList(); + private final IntArrayList length = new IntArrayList(); + + private FastLine(ByteBuffer base) { + this.base = base; + } + + public static FastLine read(ByteBuffer buf) { + FastLine r = new FastLine(buf); + r.start.add(buf.position()); + int offset = buf.position(); + while (offset < buf.limit()) { + int ch = buf.get(); + offset = buf.position(); + switch (ch) { + case '\n': + r.length.add(offset - r.start.get(r.length.size()) - 1); + return r; + case SEPARATOR_CHAR: + r.length.add(offset - r.start.get(r.length.size()) - 1); + r.start.add(offset); + break; + default: + // nothing to do for now + } + } + throw new IllegalArgumentException("Not enough bytes in buffer"); + } + + public double getDouble(int field) { + int offset = start.get(field); + int size = length.get(field); + switch (size) { + case 1: + return base.get(offset) - '0'; + case 2: + return (base.get(offset) - '0') * 10 + base.get(offset + 1) - '0'; + default: + double r = 0; + for (int i = 0; i < size; i++) { + r = 10 * r + base.get(offset + i) - '0'; + } + return r; + } + } + } + + private static final class FastLineReader implements Closeable { + private final InputStream in; + private final ByteBuffer buf = ByteBuffer.allocate(100000); + + private FastLineReader(InputStream in) throws IOException { + this.in = in; + buf.limit(0); + fillBuffer(); + } + + public FastLine read() throws IOException { + fillBuffer(); + if (buf.remaining() > 0) { + return FastLine.read(buf); + } else { + return null; + } + } + + private void fillBuffer() throws IOException { + if (buf.remaining() < 10000) { + buf.compact(); + int n = in.read(buf.array(), buf.position(), buf.remaining()); + if (n == -1) { + buf.flip(); + } else { + buf.limit(buf.position() + n); + buf.position(0); + } + } + } + + @Override + public void close() { + try { + Closeables.close(in, true); + } catch (IOException e) { + log.error(e.getMessage(), e); + } + } + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/02f75f99/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/TestASFEmail.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/TestASFEmail.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/TestASFEmail.java new file mode 100644 index 0000000..074f774 --- /dev/null +++ b/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/TestASFEmail.java @@ -0,0 +1,152 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.classifier.sgd; + +import org.apache.commons.cli2.CommandLine; +import org.apache.commons.cli2.Group; +import org.apache.commons.cli2.Option; +import org.apache.commons.cli2.builder.ArgumentBuilder; +import org.apache.commons.cli2.builder.DefaultOptionBuilder; +import org.apache.commons.cli2.builder.GroupBuilder; +import org.apache.commons.cli2.commandline.Parser; +import org.apache.commons.cli2.util.HelpFormatter; +import org.apache.commons.io.Charsets; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.fs.PathFilter; +import org.apache.hadoop.io.Text; +import org.apache.mahout.classifier.ClassifierResult; +import org.apache.mahout.classifier.ResultAnalyzer; +import org.apache.mahout.common.Pair; +import org.apache.mahout.common.iterator.sequencefile.PathType; +import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirIterator; +import org.apache.mahout.math.Vector; +import org.apache.mahout.math.VectorWritable; +import org.apache.mahout.vectorizer.encoders.Dictionary; + +import java.io.File; +import java.io.FileInputStream; +import java.io.IOException; +import java.io.OutputStreamWriter; +import java.io.PrintWriter; + +/** + * Run the ASF email, as trained by TrainASFEmail + */ +public final class TestASFEmail { + + private String inputFile; + private String modelFile; + + private TestASFEmail() {} + + public static void main(String[] args) throws IOException { + TestASFEmail runner = new TestASFEmail(); + if (runner.parseArgs(args)) { + runner.run(new PrintWriter(new OutputStreamWriter(System.out, Charsets.UTF_8), true)); + } + } + + public void run(PrintWriter output) throws IOException { + + File base = new File(inputFile); + //contains the best model + OnlineLogisticRegression classifier = + ModelSerializer.readBinary(new FileInputStream(modelFile), OnlineLogisticRegression.class); + + + Dictionary asfDictionary = new Dictionary(); + Configuration conf = new Configuration(); + PathFilter testFilter = new PathFilter() { + @Override + public boolean accept(Path path) { + return path.getName().contains("test"); + } + }; + SequenceFileDirIterator<Text, VectorWritable> iter = + new SequenceFileDirIterator<>(new Path(base.toString()), PathType.LIST, testFilter, + null, true, conf); + + long numItems = 0; + while (iter.hasNext()) { + Pair<Text, VectorWritable> next = iter.next(); + asfDictionary.intern(next.getFirst().toString()); + numItems++; + } + + System.out.println(numItems + " test files"); + ResultAnalyzer ra = new ResultAnalyzer(asfDictionary.values(), "DEFAULT"); + iter = new SequenceFileDirIterator<>(new Path(base.toString()), PathType.LIST, testFilter, + null, true, conf); + while (iter.hasNext()) { + Pair<Text, VectorWritable> next = iter.next(); + String ng = next.getFirst().toString(); + + int 
actual = asfDictionary.intern(ng); + Vector result = classifier.classifyFull(next.getSecond().get()); + int cat = result.maxValueIndex(); + double score = result.maxValue(); + double ll = classifier.logLikelihood(actual, next.getSecond().get()); + ClassifierResult cr = new ClassifierResult(asfDictionary.values().get(cat), score, ll); + ra.addInstance(asfDictionary.values().get(actual), cr); + + } + output.println(ra); + } + + boolean parseArgs(String[] args) { + DefaultOptionBuilder builder = new DefaultOptionBuilder(); + + Option help = builder.withLongName("help").withDescription("print this list").create(); + + ArgumentBuilder argumentBuilder = new ArgumentBuilder(); + Option inputFileOption = builder.withLongName("input") + .withRequired(true) + .withArgument(argumentBuilder.withName("input").withMaximum(1).create()) + .withDescription("where to get training data") + .create(); + + Option modelFileOption = builder.withLongName("model") + .withRequired(true) + .withArgument(argumentBuilder.withName("model").withMaximum(1).create()) + .withDescription("where to get a model") + .create(); + + Group normalArgs = new GroupBuilder() + .withOption(help) + .withOption(inputFileOption) + .withOption(modelFileOption) + .create(); + + Parser parser = new Parser(); + parser.setHelpOption(help); + parser.setHelpTrigger("--help"); + parser.setGroup(normalArgs); + parser.setHelpFormatter(new HelpFormatter(" ", "", " ", 130)); + CommandLine cmdLine = parser.parseAndHelp(args); + + if (cmdLine == null) { + return false; + } + + inputFile = (String) cmdLine.getValue(inputFileOption); + modelFile = (String) cmdLine.getValue(modelFileOption); + return true; + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/02f75f99/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/TestNewsGroups.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/TestNewsGroups.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/TestNewsGroups.java new file mode 100644 index 0000000..f0316e9 --- /dev/null +++ b/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/TestNewsGroups.java @@ -0,0 +1,141 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.classifier.sgd; + +import com.google.common.collect.HashMultiset; +import com.google.common.collect.Multiset; +import org.apache.commons.cli2.CommandLine; +import org.apache.commons.cli2.Group; +import org.apache.commons.cli2.Option; +import org.apache.commons.cli2.builder.ArgumentBuilder; +import org.apache.commons.cli2.builder.DefaultOptionBuilder; +import org.apache.commons.cli2.builder.GroupBuilder; +import org.apache.commons.cli2.commandline.Parser; +import org.apache.commons.cli2.util.HelpFormatter; +import org.apache.commons.io.Charsets; +import org.apache.mahout.classifier.ClassifierResult; +import org.apache.mahout.classifier.NewsgroupHelper; +import org.apache.mahout.classifier.ResultAnalyzer; +import org.apache.mahout.math.Vector; +import org.apache.mahout.vectorizer.encoders.Dictionary; + +import java.io.File; +import java.io.FileInputStream; +import java.io.IOException; +import java.io.OutputStreamWriter; +import java.io.PrintWriter; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +/** + * Run the 20 news groups test data through SGD, as trained by {@link org.apache.mahout.classifier.sgd.TrainNewsGroups}. + */ +public final class TestNewsGroups { + + private String inputFile; + private String modelFile; + + private TestNewsGroups() { + } + + public static void main(String[] args) throws IOException { + TestNewsGroups runner = new TestNewsGroups(); + if (runner.parseArgs(args)) { + runner.run(new PrintWriter(new OutputStreamWriter(System.out, Charsets.UTF_8), true)); + } + } + + public void run(PrintWriter output) throws IOException { + + File base = new File(inputFile); + //contains the best model + OnlineLogisticRegression classifier = + ModelSerializer.readBinary(new FileInputStream(modelFile), OnlineLogisticRegression.class); + + Dictionary newsGroups = new Dictionary(); + Multiset<String> overallCounts = HashMultiset.create(); + + List<File> files = new ArrayList<>(); + for (File newsgroup : base.listFiles()) { + if (newsgroup.isDirectory()) { + newsGroups.intern(newsgroup.getName()); + files.addAll(Arrays.asList(newsgroup.listFiles())); + } + } + System.out.println(files.size() + " test files"); + ResultAnalyzer ra = new ResultAnalyzer(newsGroups.values(), "DEFAULT"); + for (File file : files) { + String ng = file.getParentFile().getName(); + + int actual = newsGroups.intern(ng); + NewsgroupHelper helper = new NewsgroupHelper(); + //no leak type ensures this is a normal vector + Vector input = helper.encodeFeatureVector(file, actual, 0, overallCounts); + Vector result = classifier.classifyFull(input); + int cat = result.maxValueIndex(); + double score = result.maxValue(); + double ll = classifier.logLikelihood(actual, input); + ClassifierResult cr = new ClassifierResult(newsGroups.values().get(cat), score, ll); + ra.addInstance(newsGroups.values().get(actual), cr); + + } + output.println(ra); + } + + boolean parseArgs(String[] args) { + DefaultOptionBuilder builder = new DefaultOptionBuilder(); + + Option help = builder.withLongName("help").withDescription("print this list").create(); + + ArgumentBuilder argumentBuilder = new ArgumentBuilder(); + Option inputFileOption = builder.withLongName("input") + .withRequired(true) + .withArgument(argumentBuilder.withName("input").withMaximum(1).create()) + .withDescription("where to get training data") + .create(); + + Option modelFileOption = builder.withLongName("model") + .withRequired(true) + 
.withArgument(argumentBuilder.withName("model").withMaximum(1).create()) + .withDescription("where to get a model") + .create(); + + Group normalArgs = new GroupBuilder() + .withOption(help) + .withOption(inputFileOption) + .withOption(modelFileOption) + .create(); + + Parser parser = new Parser(); + parser.setHelpOption(help); + parser.setHelpTrigger("--help"); + parser.setGroup(normalArgs); + parser.setHelpFormatter(new HelpFormatter(" ", "", " ", 130)); + CommandLine cmdLine = parser.parseAndHelp(args); + + if (cmdLine == null) { + return false; + } + + inputFile = (String) cmdLine.getValue(inputFileOption); + modelFile = (String) cmdLine.getValue(modelFileOption); + return true; + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/02f75f99/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainASFEmail.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainASFEmail.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainASFEmail.java new file mode 100644 index 0000000..e681f92 --- /dev/null +++ b/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainASFEmail.java @@ -0,0 +1,137 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.classifier.sgd; + +import com.google.common.collect.HashMultiset; +import com.google.common.collect.Multiset; +import com.google.common.collect.Ordering; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.fs.PathFilter; +import org.apache.hadoop.io.Text; +import org.apache.mahout.common.AbstractJob; +import org.apache.mahout.common.Pair; +import org.apache.mahout.common.iterator.sequencefile.PathType; +import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirIterator; +import org.apache.mahout.ep.State; +import org.apache.mahout.math.VectorWritable; +import org.apache.mahout.vectorizer.encoders.Dictionary; + +import java.io.File; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; + +public final class TrainASFEmail extends AbstractJob { + + private TrainASFEmail() { + } + + @Override + public int run(String[] args) throws Exception { + addInputOption(); + addOutputOption(); + addOption("categories", "nc", "The number of categories to train on", true); + addOption("cardinality", "c", "The size of the vectors to use", "100000"); + addOption("threads", "t", "The number of threads to use in the learner", "20"); + addOption("poolSize", "p", "The number of CrossFoldLearners to use in the AdaptiveLogisticRegression. 
" + + "Higher values require more memory.", "5"); + if (parseArguments(args) == null) { + return -1; + } + + File base = new File(getInputPath().toString()); + + Multiset<String> overallCounts = HashMultiset.create(); + File output = new File(getOutputPath().toString()); + output.mkdirs(); + int numCats = Integer.parseInt(getOption("categories")); + int cardinality = Integer.parseInt(getOption("cardinality", "100000")); + int threadCount = Integer.parseInt(getOption("threads", "20")); + int poolSize = Integer.parseInt(getOption("poolSize", "5")); + Dictionary asfDictionary = new Dictionary(); + AdaptiveLogisticRegression learningAlgorithm = + new AdaptiveLogisticRegression(numCats, cardinality, new L1(), threadCount, poolSize); + learningAlgorithm.setInterval(800); + learningAlgorithm.setAveragingWindow(500); + + //We ran seq2encoded and split input already, so let's just build up the dictionary + Configuration conf = new Configuration(); + PathFilter trainFilter = new PathFilter() { + @Override + public boolean accept(Path path) { + return path.getName().contains("training"); + } + }; + SequenceFileDirIterator<Text, VectorWritable> iter = + new SequenceFileDirIterator<>(new Path(base.toString()), PathType.LIST, trainFilter, null, true, conf); + long numItems = 0; + while (iter.hasNext()) { + Pair<Text, VectorWritable> next = iter.next(); + asfDictionary.intern(next.getFirst().toString()); + numItems++; + } + + System.out.println(numItems + " training files"); + + SGDInfo info = new SGDInfo(); + + iter = new SequenceFileDirIterator<>(new Path(base.toString()), PathType.LIST, trainFilter, + null, true, conf); + int k = 0; + while (iter.hasNext()) { + Pair<Text, VectorWritable> next = iter.next(); + String ng = next.getFirst().toString(); + int actual = asfDictionary.intern(ng); + //we already have encoded + learningAlgorithm.train(actual, next.getSecond().get()); + k++; + State<AdaptiveLogisticRegression.Wrapper, CrossFoldLearner> best = learningAlgorithm.getBest(); + + SGDHelper.analyzeState(info, 0, k, best); + } + learningAlgorithm.close(); + //TODO: how to dissection since we aren't processing the files here + //SGDHelper.dissect(leakType, asfDictionary, learningAlgorithm, files, overallCounts); + System.out.println("exiting main, writing model to " + output); + + ModelSerializer.writeBinary(output + "/asf.model", + learningAlgorithm.getBest().getPayload().getLearner().getModels().get(0)); + + List<Integer> counts = new ArrayList<>(); + System.out.println("Word counts"); + for (String count : overallCounts.elementSet()) { + counts.add(overallCounts.count(count)); + } + Collections.sort(counts, Ordering.natural().reverse()); + k = 0; + for (Integer count : counts) { + System.out.println(k + "\t" + count); + k++; + if (k > 1000) { + break; + } + } + return 0; + } + + public static void main(String[] args) throws Exception { + TrainASFEmail trainer = new TrainASFEmail(); + trainer.run(args); + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/02f75f99/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainAdaptiveLogistic.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainAdaptiveLogistic.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainAdaptiveLogistic.java new file mode 100644 index 0000000..defb5b9 --- /dev/null +++ 
b/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainAdaptiveLogistic.java @@ -0,0 +1,377 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.classifier.sgd; + +import com.google.common.io.Resources; +import org.apache.commons.cli2.CommandLine; +import org.apache.commons.cli2.Group; +import org.apache.commons.cli2.Option; +import org.apache.commons.cli2.builder.ArgumentBuilder; +import org.apache.commons.cli2.builder.DefaultOptionBuilder; +import org.apache.commons.cli2.builder.GroupBuilder; +import org.apache.commons.cli2.commandline.Parser; +import org.apache.commons.cli2.util.HelpFormatter; +import org.apache.commons.io.Charsets; +import org.apache.mahout.classifier.sgd.AdaptiveLogisticRegression.Wrapper; +import org.apache.mahout.ep.State; +import org.apache.mahout.math.RandomAccessSparseVector; +import org.apache.mahout.math.Vector; + +import java.io.BufferedReader; +import java.io.File; +import java.io.FileInputStream; +import java.io.FileOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.InputStreamReader; +import java.io.OutputStream; +import java.io.OutputStreamWriter; +import java.io.PrintWriter; +import java.util.ArrayList; +import java.util.List; +import java.util.Locale; + +public final class TrainAdaptiveLogistic { + + private static String inputFile; + private static String outputFile; + private static AdaptiveLogisticModelParameters lmp; + private static int passes; + private static boolean showperf; + private static int skipperfnum = 99; + private static AdaptiveLogisticRegression model; + + private TrainAdaptiveLogistic() { + } + + public static void main(String[] args) throws Exception { + mainToOutput(args, new PrintWriter(new OutputStreamWriter(System.out, Charsets.UTF_8), true)); + } + + static void mainToOutput(String[] args, PrintWriter output) throws Exception { + if (parseArgs(args)) { + + CsvRecordFactory csv = lmp.getCsvRecordFactory(); + model = lmp.createAdaptiveLogisticRegression(); + State<Wrapper, CrossFoldLearner> best; + CrossFoldLearner learner = null; + + int k = 0; + for (int pass = 0; pass < passes; pass++) { + BufferedReader in = open(inputFile); + + // read variable names + csv.firstLine(in.readLine()); + + String line = in.readLine(); + while (line != null) { + // for each new line, get target and predictors + Vector input = new RandomAccessSparseVector(lmp.getNumFeatures()); + int targetValue = csv.processLine(line, input); + + // update model + model.train(targetValue, input); + k++; + + if (showperf && (k % (skipperfnum + 1) == 0)) { + + best = model.getBest(); + if (best != null) { + learner = best.getPayload().getLearner(); + } + if (learner != null) { + double averageCorrect = 
learner.percentCorrect(); + double averageLL = learner.logLikelihood(); + output.printf("%d\t%.3f\t%.2f%n", + k, averageLL, averageCorrect * 100); + } else { + output.printf(Locale.ENGLISH, + "%10d %2d %s%n", k, targetValue, + "AdaptiveLogisticRegression has not found a good model ......"); + } + } + line = in.readLine(); + } + in.close(); + } + + best = model.getBest(); + if (best != null) { + learner = best.getPayload().getLearner(); + } + if (learner == null) { + output.println("AdaptiveLogisticRegression has failed to train a model."); + return; + } + + try (OutputStream modelOutput = new FileOutputStream(outputFile)) { + lmp.saveTo(modelOutput); + } + + OnlineLogisticRegression lr = learner.getModels().get(0); + output.println(lmp.getNumFeatures()); + output.println(lmp.getTargetVariable() + " ~ "); + String sep = ""; + for (String v : csv.getTraceDictionary().keySet()) { + double weight = predictorWeight(lr, 0, csv, v); + if (weight != 0) { + output.printf(Locale.ENGLISH, "%s%.3f*%s", sep, weight, v); + sep = " + "; + } + } + output.printf("%n"); + + for (int row = 0; row < lr.getBeta().numRows(); row++) { + for (String key : csv.getTraceDictionary().keySet()) { + double weight = predictorWeight(lr, row, csv, key); + if (weight != 0) { + output.printf(Locale.ENGLISH, "%20s %.5f%n", key, weight); + } + } + for (int column = 0; column < lr.getBeta().numCols(); column++) { + output.printf(Locale.ENGLISH, "%15.9f ", lr.getBeta().get(row, column)); + } + output.println(); + } + } + + } + + private static double predictorWeight(OnlineLogisticRegression lr, int row, RecordFactory csv, String predictor) { + double weight = 0; + for (Integer column : csv.getTraceDictionary().get(predictor)) { + weight += lr.getBeta().get(row, column); + } + return weight; + } + + private static boolean parseArgs(String[] args) { + DefaultOptionBuilder builder = new DefaultOptionBuilder(); + + Option help = builder.withLongName("help") + .withDescription("print this list").create(); + + Option quiet = builder.withLongName("quiet") + .withDescription("be extra quiet").create(); + + + ArgumentBuilder argumentBuilder = new ArgumentBuilder(); + Option showperf = builder + .withLongName("showperf") + .withDescription("output performance measures during training") + .create(); + + Option inputFile = builder + .withLongName("input") + .withRequired(true) + .withArgument( + argumentBuilder.withName("input").withMaximum(1) + .create()) + .withDescription("where to get training data").create(); + + Option outputFile = builder + .withLongName("output") + .withRequired(true) + .withArgument( + argumentBuilder.withName("output").withMaximum(1) + .create()) + .withDescription("where to write the model content").create(); + + Option threads = builder.withLongName("threads") + .withArgument( + argumentBuilder.withName("threads").withDefault("4").create()) + .withDescription("the number of threads AdaptiveLogisticRegression uses") + .create(); + + + Option predictors = builder.withLongName("predictors") + .withRequired(true) + .withArgument(argumentBuilder.withName("predictors").create()) + .withDescription("a list of predictor variables").create(); + + Option types = builder + .withLongName("types") + .withRequired(true) + .withArgument(argumentBuilder.withName("types").create()) + .withDescription( + "a list of predictor variable types (numeric, word, or text)") + .create(); + + Option target = builder + .withLongName("target") + .withDescription("the name of the target variable") + .withRequired(true) + .withArgument( + 
argumentBuilder.withName("target").withMaximum(1) + .create()) + .create(); + + Option targetCategories = builder + .withLongName("categories") + .withDescription("the number of target categories to be considered") + .withRequired(true) + .withArgument(argumentBuilder.withName("categories").withMaximum(1).create()) + .create(); + + + Option features = builder + .withLongName("features") + .withDescription("the number of internal hashed features to use") + .withArgument( + argumentBuilder.withName("numFeatures") + .withDefault("1000").withMaximum(1).create()) + .create(); + + Option passes = builder + .withLongName("passes") + .withDescription("the number of times to pass over the input data") + .withArgument( + argumentBuilder.withName("passes").withDefault("2") + .withMaximum(1).create()) + .create(); + + Option interval = builder.withLongName("interval") + .withArgument( + argumentBuilder.withName("interval").withDefault("500").create()) + .withDescription("the interval property of AdaptiveLogisticRegression") + .create(); + + Option window = builder.withLongName("window") + .withArgument( + argumentBuilder.withName("window").withDefault("800").create()) + .withDescription("the average propery of AdaptiveLogisticRegression") + .create(); + + Option skipperfnum = builder.withLongName("skipperfnum") + .withArgument( + argumentBuilder.withName("skipperfnum").withDefault("99").create()) + .withDescription("show performance measures every (skipperfnum + 1) rows") + .create(); + + Option prior = builder.withLongName("prior") + .withArgument( + argumentBuilder.withName("prior").withDefault("L1").create()) + .withDescription("the prior algorithm to use: L1, L2, ebp, tp, up") + .create(); + + Option priorOption = builder.withLongName("prioroption") + .withArgument( + argumentBuilder.withName("prioroption").create()) + .withDescription("constructor parameter for ElasticBandPrior and TPrior") + .create(); + + Option auc = builder.withLongName("auc") + .withArgument( + argumentBuilder.withName("auc").withDefault("global").create()) + .withDescription("the auc to use: global or grouped") + .create(); + + + + Group normalArgs = new GroupBuilder().withOption(help) + .withOption(quiet).withOption(inputFile).withOption(outputFile) + .withOption(target).withOption(targetCategories) + .withOption(predictors).withOption(types).withOption(passes) + .withOption(interval).withOption(window).withOption(threads) + .withOption(prior).withOption(features).withOption(showperf) + .withOption(skipperfnum).withOption(priorOption).withOption(auc) + .create(); + + Parser parser = new Parser(); + parser.setHelpOption(help); + parser.setHelpTrigger("--help"); + parser.setGroup(normalArgs); + parser.setHelpFormatter(new HelpFormatter(" ", "", " ", 130)); + CommandLine cmdLine = parser.parseAndHelp(args); + + if (cmdLine == null) { + return false; + } + + TrainAdaptiveLogistic.inputFile = getStringArgument(cmdLine, inputFile); + TrainAdaptiveLogistic.outputFile = getStringArgument(cmdLine, + outputFile); + + List<String> typeList = new ArrayList<>(); + for (Object x : cmdLine.getValues(types)) { + typeList.add(x.toString()); + } + + List<String> predictorList = new ArrayList<>(); + for (Object x : cmdLine.getValues(predictors)) { + predictorList.add(x.toString()); + } + + lmp = new AdaptiveLogisticModelParameters(); + lmp.setTargetVariable(getStringArgument(cmdLine, target)); + lmp.setMaxTargetCategories(getIntegerArgument(cmdLine, targetCategories)); + lmp.setNumFeatures(getIntegerArgument(cmdLine, features)); + 
lmp.setInterval(getIntegerArgument(cmdLine, interval)); + lmp.setAverageWindow(getIntegerArgument(cmdLine, window)); + lmp.setThreads(getIntegerArgument(cmdLine, threads)); + lmp.setAuc(getStringArgument(cmdLine, auc)); + lmp.setPrior(getStringArgument(cmdLine, prior)); + if (cmdLine.getValue(priorOption) != null) { + lmp.setPriorOption(getDoubleArgument(cmdLine, priorOption)); + } + lmp.setTypeMap(predictorList, typeList); + TrainAdaptiveLogistic.showperf = getBooleanArgument(cmdLine, showperf); + TrainAdaptiveLogistic.skipperfnum = getIntegerArgument(cmdLine, skipperfnum); + TrainAdaptiveLogistic.passes = getIntegerArgument(cmdLine, passes); + + lmp.checkParameters(); + + return true; + } + + private static String getStringArgument(CommandLine cmdLine, + Option inputFile) { + return (String) cmdLine.getValue(inputFile); + } + + private static boolean getBooleanArgument(CommandLine cmdLine, Option option) { + return cmdLine.hasOption(option); + } + + private static int getIntegerArgument(CommandLine cmdLine, Option features) { + return Integer.parseInt((String) cmdLine.getValue(features)); + } + + private static double getDoubleArgument(CommandLine cmdLine, Option op) { + return Double.parseDouble((String) cmdLine.getValue(op)); + } + + public static AdaptiveLogisticRegression getModel() { + return model; + } + + public static LogisticModelParameters getParameters() { + return lmp; + } + + static BufferedReader open(String inputFile) throws IOException { + InputStream in; + try { + in = Resources.getResource(inputFile).openStream(); + } catch (IllegalArgumentException e) { + in = new FileInputStream(new File(inputFile)); + } + return new BufferedReader(new InputStreamReader(in, Charsets.UTF_8)); + } + +}
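A note for readers picking these examples up: TestASFEmail and TestNewsGroups share one scoring pattern. They deserialize a trained OnlineLogisticRegression via ModelSerializer, call classifyFull() on each encoded document, and take maxValueIndex() as the predicted category. The sketch below isolates that step outside the drivers. It is not part of the patch: the model path and the vector cardinality are placeholder assumptions, and a real input vector must be encoded with the same hashing scheme used at training time.

    import java.io.FileInputStream;
    import java.io.IOException;

    import org.apache.mahout.classifier.sgd.ModelSerializer;
    import org.apache.mahout.classifier.sgd.OnlineLogisticRegression;
    import org.apache.mahout.math.RandomAccessSparseVector;
    import org.apache.mahout.math.Vector;

    public final class ScoreOneDocument {
      public static void main(String[] args) throws IOException {
        // Placeholder path; in the drivers above this comes from --model.
        OnlineLogisticRegression model = ModelSerializer.readBinary(
            new FileInputStream("/tmp/news.model"), OnlineLogisticRegression.class);

        // Stand-in for a document vector; 100000 is an assumed hashed-feature
        // cardinality, and a real input must use the training-time encoding.
        Vector input = new RandomAccessSparseVector(100000);

        Vector scores = model.classifyFull(input); // one score per target category
        int cat = scores.maxValueIndex();          // index into the training-time Dictionary
        System.out.printf("category %d, score %.3f%n", cat, scores.maxValue());
      }
    }

Note that the test drivers rebuild the category Dictionary by re-interning directory or key names, so the predicted index only maps back to the right label if those names enumerate in the same order as during training.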
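Both trainers drive the learner the same way: construct an AdaptiveLogisticRegression over a hashed feature space, stream (target, vector) pairs through train(), then close() it and pull the best CrossFoldLearner out of the evolutionary pool. A condensed sketch of that loop follows, with synthetic random vectors standing in for encoded documents; the sizes and data here are purely illustrative.

    import java.util.Random;

    import org.apache.mahout.classifier.sgd.AdaptiveLogisticRegression;
    import org.apache.mahout.classifier.sgd.CrossFoldLearner;
    import org.apache.mahout.classifier.sgd.L1;
    import org.apache.mahout.ep.State;
    import org.apache.mahout.math.RandomAccessSparseVector;
    import org.apache.mahout.math.Vector;

    public final class AdaptiveTrainingSketch {
      public static void main(String[] args) {
        int numCategories = 2;
        int cardinality = 1000;   // assumed hashed feature space for this toy run
        AdaptiveLogisticRegression learner =
            new AdaptiveLogisticRegression(numCategories, cardinality, new L1());
        learner.setInterval(800);          // same tuning the trainers above apply
        learner.setAveragingWindow(500);

        Random rand = new Random(42);
        for (int i = 0; i < 10000; i++) {
          // Synthetic stand-in for an encoded document: a few random features.
          Vector v = new RandomAccessSparseVector(cardinality);
          for (int j = 0; j < 10; j++) {
            v.set(rand.nextInt(cardinality), 1);
          }
          learner.train(rand.nextInt(numCategories), v);
        }
        learner.close();                   // drains the training threads, as both trainers do

        State<AdaptiveLogisticRegression.Wrapper, CrossFoldLearner> best = learner.getBest();
        if (best != null) {
          CrossFoldLearner folds = best.getPayload().getLearner();
          System.out.println("held-out percent correct: " + folds.percentCorrect());
        }
      }
    }

getBest() can return null early in a run, before any cross-validated candidate has seen enough data, which is why TrainAdaptiveLogistic and this sketch both null-check it.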
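TrainAdaptiveLogistic's input handling deserves a word: CsvRecordFactory reads the header via firstLine() to learn the column order, and each subsequent processLine() call hash-encodes the predictor columns into the supplied vector and returns that row's target category index. A minimal sketch under assumed column names (x, y and a two-valued target color):

    import java.util.Arrays;

    import org.apache.mahout.classifier.sgd.AdaptiveLogisticModelParameters;
    import org.apache.mahout.classifier.sgd.CsvRecordFactory;
    import org.apache.mahout.math.RandomAccessSparseVector;
    import org.apache.mahout.math.Vector;

    public final class CsvEncodingSketch {
      public static void main(String[] args) {
        AdaptiveLogisticModelParameters lmp = new AdaptiveLogisticModelParameters();
        lmp.setTargetVariable("color");       // assumed target column name
        lmp.setMaxTargetCategories(2);
        lmp.setNumFeatures(1000);
        lmp.setTypeMap(Arrays.asList("x", "y"), Arrays.asList("numeric"));

        CsvRecordFactory csv = lmp.getCsvRecordFactory();
        csv.firstLine("x,y,color");           // header establishes the column order

        Vector input = new RandomAccessSparseVector(lmp.getNumFeatures());
        // Encodes the predictors into input and returns the target index.
        int target = csv.processLine("0.1,0.9,1", input);
        System.out.println("target category " + target + ", encoded "
            + input.getNumNondefaultElements() + " cells");
      }
    }

Passing a single type for two predictors works because setTypeMap() reuses the last type specifier once the type list runs short, which is also why the trainer accepts "--types numeric" for several numeric predictors.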
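Finally, a plausible end-to-end invocation of the new trainer. Everything here is illustrative rather than part of the patch: the option values mirror Mahout's donut logistic-regression example, and open() resolves --input against the classpath before falling back to the filesystem.

    import org.apache.mahout.classifier.sgd.TrainAdaptiveLogistic;

    public final class TrainDonutExample {
      public static void main(String[] args) throws Exception {
        // Hypothetical arguments; donut.csv must be a classpath resource or a file.
        TrainAdaptiveLogistic.main(new String[] {
            "--input", "donut.csv",
            "--output", "/tmp/donut.model",   // placeholder output path
            "--target", "color",
            "--categories", "2",
            "--predictors", "x", "y",
            "--types", "numeric",
            "--passes", "5",
            "--features", "1000"
        });
      }
    }

After the final pass, the trainer prints a model formula: predictorWeight() sums the beta entries over the hashed columns recorded for each predictor name in the trace dictionary, and only predictors whose summed weight is non-zero appear in the printed formula.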
