http://git-wip-us.apache.org/repos/asf/mahout/blob/99a5358f/examples/src/main/java/org/apache/mahout/classifier/sgd/AdaptiveLogisticModelParameters.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/mahout/classifier/sgd/AdaptiveLogisticModelParameters.java b/examples/src/main/java/org/apache/mahout/classifier/sgd/AdaptiveLogisticModelParameters.java deleted file mode 100644 index b2ce8b1..0000000 --- a/examples/src/main/java/org/apache/mahout/classifier/sgd/AdaptiveLogisticModelParameters.java +++ /dev/null @@ -1,236 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.mahout.classifier.sgd; - -import org.apache.mahout.math.stats.GlobalOnlineAuc; -import org.apache.mahout.math.stats.GroupedOnlineAuc; -import org.apache.mahout.math.stats.OnlineAuc; - -import java.io.DataInput; -import java.io.DataInputStream; -import java.io.DataOutput; -import java.io.DataOutputStream; -import java.io.File; -import java.io.FileInputStream; -import java.io.IOException; -import java.io.InputStream; -import java.io.OutputStream; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Locale; -import java.util.Map; - -public class AdaptiveLogisticModelParameters extends LogisticModelParameters { - - private AdaptiveLogisticRegression alr; - private int interval = 800; - private int averageWindow = 500; - private int threads = 4; - private String prior = "L1"; - private double priorOption = Double.NaN; - private String auc = null; - - public AdaptiveLogisticRegression createAdaptiveLogisticRegression() { - - if (alr == null) { - alr = new AdaptiveLogisticRegression(getMaxTargetCategories(), - getNumFeatures(), createPrior(prior, priorOption)); - alr.setInterval(interval); - alr.setAveragingWindow(averageWindow); - alr.setThreadCount(threads); - alr.setAucEvaluator(createAUC(auc)); - } - return alr; - } - - public void checkParameters() { - if (prior != null) { - String priorUppercase = prior.toUpperCase(Locale.ENGLISH).trim(); - if (("TP".equals(priorUppercase) || "EBP".equals(priorUppercase)) && Double.isNaN(priorOption)) { - throw new IllegalArgumentException("You must specify a double value for TPrior and ElasticBandPrior."); - } - } - } - - private static PriorFunction createPrior(String cmd, double priorOption) { - if (cmd == null) { - return null; - } - if ("L1".equals(cmd.toUpperCase(Locale.ENGLISH).trim())) { - return new L1(); - } - if ("L2".equals(cmd.toUpperCase(Locale.ENGLISH).trim())) { - return new L2(); - } - if ("UP".equals(cmd.toUpperCase(Locale.ENGLISH).trim())) { - 
return new UniformPrior(); - } - if ("TP".equals(cmd.toUpperCase(Locale.ENGLISH).trim())) { - return new TPrior(priorOption); - } - if ("EBP".equals(cmd.toUpperCase(Locale.ENGLISH).trim())) { - return new ElasticBandPrior(priorOption); - } - - return null; - } - - private static OnlineAuc createAUC(String cmd) { - if (cmd == null) { - return null; - } - if ("GLOBAL".equals(cmd.toUpperCase(Locale.ENGLISH).trim())) { - return new GlobalOnlineAuc(); - } - if ("GROUPED".equals(cmd.toUpperCase(Locale.ENGLISH).trim())) { - return new GroupedOnlineAuc(); - } - return null; - } - - @Override - public void saveTo(OutputStream out) throws IOException { - if (alr != null) { - alr.close(); - } - setTargetCategories(getCsvRecordFactory().getTargetCategories()); - write(new DataOutputStream(out)); - } - - @Override - public void write(DataOutput out) throws IOException { - out.writeUTF(getTargetVariable()); - out.writeInt(getTypeMap().size()); - for (Map.Entry<String, String> entry : getTypeMap().entrySet()) { - out.writeUTF(entry.getKey()); - out.writeUTF(entry.getValue()); - } - out.writeInt(getNumFeatures()); - out.writeInt(getMaxTargetCategories()); - out.writeInt(getTargetCategories().size()); - for (String category : getTargetCategories()) { - out.writeUTF(category); - } - - out.writeInt(interval); - out.writeInt(averageWindow); - out.writeInt(threads); - out.writeUTF(prior); - out.writeDouble(priorOption); - out.writeUTF(auc); - - // skip csv - alr.write(out); - } - - @Override - public void readFields(DataInput in) throws IOException { - setTargetVariable(in.readUTF()); - int typeMapSize = in.readInt(); - Map<String, String> typeMap = new HashMap<>(typeMapSize); - for (int i = 0; i < typeMapSize; i++) { - String key = in.readUTF(); - String value = in.readUTF(); - typeMap.put(key, value); - } - setTypeMap(typeMap); - - setNumFeatures(in.readInt()); - setMaxTargetCategories(in.readInt()); - int targetCategoriesSize = in.readInt(); - List<String> targetCategories = new 
ArrayList<>(targetCategoriesSize); - for (int i = 0; i < targetCategoriesSize; i++) { - targetCategories.add(in.readUTF()); - } - setTargetCategories(targetCategories); - - interval = in.readInt(); - averageWindow = in.readInt(); - threads = in.readInt(); - prior = in.readUTF(); - priorOption = in.readDouble(); - auc = in.readUTF(); - - alr = new AdaptiveLogisticRegression(); - alr.readFields(in); - } - - - private static AdaptiveLogisticModelParameters loadFromStream(InputStream in) throws IOException { - AdaptiveLogisticModelParameters result = new AdaptiveLogisticModelParameters(); - result.readFields(new DataInputStream(in)); - return result; - } - - public static AdaptiveLogisticModelParameters loadFromFile(File in) throws IOException { - try (InputStream input = new FileInputStream(in)) { - return loadFromStream(input); - } - } - - public int getInterval() { - return interval; - } - - public void setInterval(int interval) { - this.interval = interval; - } - - public int getAverageWindow() { - return averageWindow; - } - - public void setAverageWindow(int averageWindow) { - this.averageWindow = averageWindow; - } - - public int getThreads() { - return threads; - } - - public void setThreads(int threads) { - this.threads = threads; - } - - public String getPrior() { - return prior; - } - - public void setPrior(String prior) { - this.prior = prior; - } - - public String getAuc() { - return auc; - } - - public void setAuc(String auc) { - this.auc = auc; - } - - public double getPriorOption() { - return priorOption; - } - - public void setPriorOption(double priorOption) { - this.priorOption = priorOption; - } - - -}
http://git-wip-us.apache.org/repos/asf/mahout/blob/99a5358f/examples/src/main/java/org/apache/mahout/classifier/sgd/LogisticModelParameters.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/mahout/classifier/sgd/LogisticModelParameters.java b/examples/src/main/java/org/apache/mahout/classifier/sgd/LogisticModelParameters.java deleted file mode 100644 index e762924..0000000 --- a/examples/src/main/java/org/apache/mahout/classifier/sgd/LogisticModelParameters.java +++ /dev/null @@ -1,265 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.mahout.classifier.sgd; - -import com.google.common.base.Preconditions; -import com.google.common.io.Closeables; -import java.io.DataInput; -import java.io.DataInputStream; -import java.io.DataOutput; -import java.io.DataOutputStream; -import java.io.File; -import java.io.FileInputStream; -import java.io.IOException; -import java.io.InputStream; -import java.io.OutputStream; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.Iterator; -import java.util.List; -import java.util.Map; -import org.apache.hadoop.io.Writable; - -/** - * Encapsulates everything we need to know about a model and how it reads and vectorizes its input. - * This encapsulation allows us to coherently save and restore a model from a file. This also - * allows us to keep command line arguments that affect learning in a coherent way. - */ -public class LogisticModelParameters implements Writable { - private String targetVariable; - private Map<String, String> typeMap; - private int numFeatures; - private boolean useBias; - private int maxTargetCategories; - private List<String> targetCategories; - private double lambda; - private double learningRate; - private CsvRecordFactory csv; - private OnlineLogisticRegression lr; - - /** - * Returns a CsvRecordFactory compatible with this logistic model. The reason that this is tied - * in here is so that we have access to the list of target categories when it comes time to save - * the model. If the input isn't CSV, then calling setTargetCategories before calling saveTo will - * suffice. - * - * @return The CsvRecordFactory. 
- */ - public CsvRecordFactory getCsvRecordFactory() { - if (csv == null) { - csv = new CsvRecordFactory(getTargetVariable(), getTypeMap()) - .maxTargetValue(getMaxTargetCategories()) - .includeBiasTerm(useBias()); - if (targetCategories != null) { - csv.defineTargetCategories(targetCategories); - } - } - return csv; - } - - /** - * Creates a logistic regression trainer using the parameters collected here. - * - * @return The newly allocated OnlineLogisticRegression object - */ - public OnlineLogisticRegression createRegression() { - if (lr == null) { - lr = new OnlineLogisticRegression(getMaxTargetCategories(), getNumFeatures(), new L1()) - .lambda(getLambda()) - .learningRate(getLearningRate()) - .alpha(1 - 1.0e-3); - } - return lr; - } - - /** - * Saves a model to an output stream. - */ - public void saveTo(OutputStream out) throws IOException { - Closeables.close(lr, false); - targetCategories = getCsvRecordFactory().getTargetCategories(); - write(new DataOutputStream(out)); - } - - /** - * Reads a model from a stream. - */ - public static LogisticModelParameters loadFrom(InputStream in) throws IOException { - LogisticModelParameters result = new LogisticModelParameters(); - result.readFields(new DataInputStream(in)); - return result; - } - - /** - * Reads a model from a file. - * @throws IOException If there is an error opening or closing the file. 
- */ - public static LogisticModelParameters loadFrom(File in) throws IOException { - try (InputStream input = new FileInputStream(in)) { - return loadFrom(input); - } - } - - - @Override - public void write(DataOutput out) throws IOException { - out.writeUTF(targetVariable); - out.writeInt(typeMap.size()); - for (Map.Entry<String,String> entry : typeMap.entrySet()) { - out.writeUTF(entry.getKey()); - out.writeUTF(entry.getValue()); - } - out.writeInt(numFeatures); - out.writeBoolean(useBias); - out.writeInt(maxTargetCategories); - - if (targetCategories == null) { - out.writeInt(0); - } else { - out.writeInt(targetCategories.size()); - for (String category : targetCategories) { - out.writeUTF(category); - } - } - out.writeDouble(lambda); - out.writeDouble(learningRate); - // skip csv - lr.write(out); - } - - @Override - public void readFields(DataInput in) throws IOException { - targetVariable = in.readUTF(); - int typeMapSize = in.readInt(); - typeMap = new HashMap<>(typeMapSize); - for (int i = 0; i < typeMapSize; i++) { - String key = in.readUTF(); - String value = in.readUTF(); - typeMap.put(key, value); - } - numFeatures = in.readInt(); - useBias = in.readBoolean(); - maxTargetCategories = in.readInt(); - int targetCategoriesSize = in.readInt(); - targetCategories = new ArrayList<>(targetCategoriesSize); - for (int i = 0; i < targetCategoriesSize; i++) { - targetCategories.add(in.readUTF()); - } - lambda = in.readDouble(); - learningRate = in.readDouble(); - csv = null; - lr = new OnlineLogisticRegression(); - lr.readFields(in); - } - - /** - * Sets the types of the predictors. This will later be used when reading CSV data. If you don't - * use the CSV data and convert to vectors on your own, you don't need to call this. - * - * @param predictorList The list of variable names. - * @param typeList The list of types in the format preferred by CsvRecordFactory. 
- */ - public void setTypeMap(Iterable<String> predictorList, List<String> typeList) { - Preconditions.checkArgument(!typeList.isEmpty(), "Must have at least one type specifier"); - typeMap = new HashMap<>(); - Iterator<String> iTypes = typeList.iterator(); - String lastType = null; - for (Object x : predictorList) { - // type list can be short .. we just repeat last spec - if (iTypes.hasNext()) { - lastType = iTypes.next(); - } - typeMap.put(x.toString(), lastType); - } - } - - /** - * Sets the target variable. If you don't use the CSV record factory, then this is irrelevant. - * - * @param targetVariable The name of the target variable. - */ - public void setTargetVariable(String targetVariable) { - this.targetVariable = targetVariable; - } - - /** - * Sets the number of target categories to be considered. - * - * @param maxTargetCategories The number of target categories. - */ - public void setMaxTargetCategories(int maxTargetCategories) { - this.maxTargetCategories = maxTargetCategories; - } - - public void setNumFeatures(int numFeatures) { - this.numFeatures = numFeatures; - } - - public void setTargetCategories(List<String> targetCategories) { - this.targetCategories = targetCategories; - maxTargetCategories = targetCategories.size(); - } - - public List<String> getTargetCategories() { - return this.targetCategories; - } - - public void setUseBias(boolean useBias) { - this.useBias = useBias; - } - - public boolean useBias() { - return useBias; - } - - public String getTargetVariable() { - return targetVariable; - } - - public Map<String, String> getTypeMap() { - return typeMap; - } - - public void setTypeMap(Map<String, String> map) { - this.typeMap = map; - } - - public int getNumFeatures() { - return numFeatures; - } - - public int getMaxTargetCategories() { - return maxTargetCategories; - } - - public double getLambda() { - return lambda; - } - - public void setLambda(double lambda) { - this.lambda = lambda; - } - - public double getLearningRate() { - 
return learningRate; - } - - public void setLearningRate(double learningRate) { - this.learningRate = learningRate; - } -} http://git-wip-us.apache.org/repos/asf/mahout/blob/99a5358f/examples/src/main/java/org/apache/mahout/classifier/sgd/PrintResourceOrFile.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/mahout/classifier/sgd/PrintResourceOrFile.java b/examples/src/main/java/org/apache/mahout/classifier/sgd/PrintResourceOrFile.java deleted file mode 100644 index 3ec6a06..0000000 --- a/examples/src/main/java/org/apache/mahout/classifier/sgd/PrintResourceOrFile.java +++ /dev/null @@ -1,42 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.mahout.classifier.sgd; - -import com.google.common.base.Preconditions; - -import java.io.BufferedReader; - -/** - * Uses the same logic as TrainLogistic and RunLogistic for finding an input, but instead - * of processing the input, this class just prints the input to standard out. 
- */ -public final class PrintResourceOrFile { - - private PrintResourceOrFile() { - } - - public static void main(String[] args) throws Exception { - Preconditions.checkArgument(args.length == 1, "Must have a single argument that names a file or resource."); - try (BufferedReader in = TrainLogistic.open(args[0])){ - String line; - while ((line = in.readLine()) != null) { - System.out.println(line); - } - } - } -} http://git-wip-us.apache.org/repos/asf/mahout/blob/99a5358f/examples/src/main/java/org/apache/mahout/classifier/sgd/RunAdaptiveLogistic.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/mahout/classifier/sgd/RunAdaptiveLogistic.java b/examples/src/main/java/org/apache/mahout/classifier/sgd/RunAdaptiveLogistic.java deleted file mode 100644 index 678a8f5..0000000 --- a/examples/src/main/java/org/apache/mahout/classifier/sgd/RunAdaptiveLogistic.java +++ /dev/null @@ -1,197 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.mahout.classifier.sgd; - -import org.apache.commons.cli2.CommandLine; -import org.apache.commons.cli2.Group; -import org.apache.commons.cli2.Option; -import org.apache.commons.cli2.builder.ArgumentBuilder; -import org.apache.commons.cli2.builder.DefaultOptionBuilder; -import org.apache.commons.cli2.builder.GroupBuilder; -import org.apache.commons.cli2.commandline.Parser; -import org.apache.commons.cli2.util.HelpFormatter; -import org.apache.commons.io.Charsets; -import org.apache.mahout.classifier.sgd.AdaptiveLogisticRegression.Wrapper; -import org.apache.mahout.ep.State; -import org.apache.mahout.math.SequentialAccessSparseVector; -import org.apache.mahout.math.Vector; - -import java.io.BufferedReader; -import java.io.BufferedWriter; -import java.io.File; -import java.io.FileOutputStream; -import java.io.OutputStreamWriter; -import java.io.PrintWriter; -import java.util.HashMap; -import java.util.Map; - -public final class RunAdaptiveLogistic { - - private static String inputFile; - private static String modelFile; - private static String outputFile; - private static String idColumn; - private static boolean maxScoreOnly; - - private RunAdaptiveLogistic() { - } - - public static void main(String[] args) throws Exception { - mainToOutput(args, new PrintWriter(new OutputStreamWriter(System.out, Charsets.UTF_8), true)); - } - - static void mainToOutput(String[] args, PrintWriter output) throws Exception { - if (!parseArgs(args)) { - return; - } - AdaptiveLogisticModelParameters lmp = AdaptiveLogisticModelParameters - .loadFromFile(new File(modelFile)); - - CsvRecordFactory csv = lmp.getCsvRecordFactory(); - csv.setIdName(idColumn); - - AdaptiveLogisticRegression lr = lmp.createAdaptiveLogisticRegression(); - - State<Wrapper, CrossFoldLearner> best = lr.getBest(); - if (best == null) { - output.println("AdaptiveLogisticRegression has not be trained probably."); - return; - } - CrossFoldLearner learner = best.getPayload().getLearner(); - - 
BufferedReader in = TrainAdaptiveLogistic.open(inputFile); - int k = 0; - - try (BufferedWriter out = new BufferedWriter(new OutputStreamWriter(new FileOutputStream(outputFile), - Charsets.UTF_8))) { - out.write(idColumn + ",target,score"); - out.newLine(); - - String line = in.readLine(); - csv.firstLine(line); - line = in.readLine(); - Map<String, Double> results = new HashMap<>(); - while (line != null) { - Vector v = new SequentialAccessSparseVector(lmp.getNumFeatures()); - csv.processLine(line, v, false); - Vector scores = learner.classifyFull(v); - results.clear(); - if (maxScoreOnly) { - results.put(csv.getTargetLabel(scores.maxValueIndex()), - scores.maxValue()); - } else { - for (int i = 0; i < scores.size(); i++) { - results.put(csv.getTargetLabel(i), scores.get(i)); - } - } - - for (Map.Entry<String, Double> entry : results.entrySet()) { - out.write(csv.getIdString(line) + ',' + entry.getKey() + ',' + entry.getValue()); - out.newLine(); - } - k++; - if (k % 100 == 0) { - output.println(k + " records processed"); - } - line = in.readLine(); - } - out.flush(); - } - output.println(k + " records processed totally."); - } - - private static boolean parseArgs(String[] args) { - DefaultOptionBuilder builder = new DefaultOptionBuilder(); - - Option help = builder.withLongName("help") - .withDescription("print this list").create(); - - Option quiet = builder.withLongName("quiet") - .withDescription("be extra quiet").create(); - - ArgumentBuilder argumentBuilder = new ArgumentBuilder(); - Option inputFileOption = builder - .withLongName("input") - .withRequired(true) - .withArgument( - argumentBuilder.withName("input").withMaximum(1) - .create()) - .withDescription("where to get training data").create(); - - Option modelFileOption = builder - .withLongName("model") - .withRequired(true) - .withArgument( - argumentBuilder.withName("model").withMaximum(1) - .create()) - .withDescription("where to get the trained model").create(); - - Option outputFileOption = 
builder - .withLongName("output") - .withRequired(true) - .withDescription("the file path to output scores") - .withArgument(argumentBuilder.withName("output").withMaximum(1).create()) - .create(); - - Option idColumnOption = builder - .withLongName("idcolumn") - .withRequired(true) - .withDescription("the name of the id column for each record") - .withArgument(argumentBuilder.withName("idcolumn").withMaximum(1).create()) - .create(); - - Option maxScoreOnlyOption = builder - .withLongName("maxscoreonly") - .withDescription("only output the target label with max scores") - .create(); - - Group normalArgs = new GroupBuilder() - .withOption(help).withOption(quiet) - .withOption(inputFileOption).withOption(modelFileOption) - .withOption(outputFileOption).withOption(idColumnOption) - .withOption(maxScoreOnlyOption) - .create(); - - Parser parser = new Parser(); - parser.setHelpOption(help); - parser.setHelpTrigger("--help"); - parser.setGroup(normalArgs); - parser.setHelpFormatter(new HelpFormatter(" ", "", " ", 130)); - CommandLine cmdLine = parser.parseAndHelp(args); - - if (cmdLine == null) { - return false; - } - - inputFile = getStringArgument(cmdLine, inputFileOption); - modelFile = getStringArgument(cmdLine, modelFileOption); - outputFile = getStringArgument(cmdLine, outputFileOption); - idColumn = getStringArgument(cmdLine, idColumnOption); - maxScoreOnly = getBooleanArgument(cmdLine, maxScoreOnlyOption); - return true; - } - - private static boolean getBooleanArgument(CommandLine cmdLine, Option option) { - return cmdLine.hasOption(option); - } - - private static String getStringArgument(CommandLine cmdLine, Option inputFile) { - return (String) cmdLine.getValue(inputFile); - } - -} http://git-wip-us.apache.org/repos/asf/mahout/blob/99a5358f/examples/src/main/java/org/apache/mahout/classifier/sgd/RunLogistic.java ---------------------------------------------------------------------- diff --git 
a/examples/src/main/java/org/apache/mahout/classifier/sgd/RunLogistic.java b/examples/src/main/java/org/apache/mahout/classifier/sgd/RunLogistic.java deleted file mode 100644 index 2d57016..0000000 --- a/examples/src/main/java/org/apache/mahout/classifier/sgd/RunLogistic.java +++ /dev/null @@ -1,163 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.mahout.classifier.sgd; - -import org.apache.commons.cli2.CommandLine; -import org.apache.commons.cli2.Group; -import org.apache.commons.cli2.Option; -import org.apache.commons.cli2.builder.ArgumentBuilder; -import org.apache.commons.cli2.builder.DefaultOptionBuilder; -import org.apache.commons.cli2.builder.GroupBuilder; -import org.apache.commons.cli2.commandline.Parser; -import org.apache.commons.cli2.util.HelpFormatter; -import org.apache.commons.io.Charsets; -import org.apache.mahout.classifier.evaluation.Auc; -import org.apache.mahout.math.Matrix; -import org.apache.mahout.math.SequentialAccessSparseVector; -import org.apache.mahout.math.Vector; - -import java.io.BufferedReader; -import java.io.File; -import java.io.OutputStreamWriter; -import java.io.PrintWriter; -import java.util.Locale; - -public final class RunLogistic { - - private static String inputFile; - private static String modelFile; - private static boolean showAuc; - private static boolean showScores; - private static boolean showConfusion; - - private RunLogistic() { - } - - public static void main(String[] args) throws Exception { - mainToOutput(args, new PrintWriter(new OutputStreamWriter(System.out, Charsets.UTF_8), true)); - } - - static void mainToOutput(String[] args, PrintWriter output) throws Exception { - if (parseArgs(args)) { - if (!showAuc && !showConfusion && !showScores) { - showAuc = true; - showConfusion = true; - } - - Auc collector = new Auc(); - LogisticModelParameters lmp = LogisticModelParameters.loadFrom(new File(modelFile)); - - CsvRecordFactory csv = lmp.getCsvRecordFactory(); - OnlineLogisticRegression lr = lmp.createRegression(); - BufferedReader in = TrainLogistic.open(inputFile); - String line = in.readLine(); - csv.firstLine(line); - line = in.readLine(); - if (showScores) { - output.println("\"target\",\"model-output\",\"log-likelihood\""); - } - while (line != null) { - Vector v = new SequentialAccessSparseVector(lmp.getNumFeatures()); - 
int target = csv.processLine(line, v); - - double score = lr.classifyScalar(v); - if (showScores) { - output.printf(Locale.ENGLISH, "%d,%.3f,%.6f%n", target, score, lr.logLikelihood(target, v)); - } - collector.add(target, score); - line = in.readLine(); - } - - if (showAuc) { - output.printf(Locale.ENGLISH, "AUC = %.2f%n", collector.auc()); - } - if (showConfusion) { - Matrix m = collector.confusion(); - output.printf(Locale.ENGLISH, "confusion: [[%.1f, %.1f], [%.1f, %.1f]]%n", - m.get(0, 0), m.get(1, 0), m.get(0, 1), m.get(1, 1)); - m = collector.entropy(); - output.printf(Locale.ENGLISH, "entropy: [[%.1f, %.1f], [%.1f, %.1f]]%n", - m.get(0, 0), m.get(1, 0), m.get(0, 1), m.get(1, 1)); - } - } - } - - private static boolean parseArgs(String[] args) { - DefaultOptionBuilder builder = new DefaultOptionBuilder(); - - Option help = builder.withLongName("help").withDescription("print this list").create(); - - Option quiet = builder.withLongName("quiet").withDescription("be extra quiet").create(); - - Option auc = builder.withLongName("auc").withDescription("print AUC").create(); - Option confusion = builder.withLongName("confusion").withDescription("print confusion matrix").create(); - - Option scores = builder.withLongName("scores").withDescription("print scores").create(); - - ArgumentBuilder argumentBuilder = new ArgumentBuilder(); - Option inputFileOption = builder.withLongName("input") - .withRequired(true) - .withArgument(argumentBuilder.withName("input").withMaximum(1).create()) - .withDescription("where to get training data") - .create(); - - Option modelFileOption = builder.withLongName("model") - .withRequired(true) - .withArgument(argumentBuilder.withName("model").withMaximum(1).create()) - .withDescription("where to get a model") - .create(); - - Group normalArgs = new GroupBuilder() - .withOption(help) - .withOption(quiet) - .withOption(auc) - .withOption(scores) - .withOption(confusion) - .withOption(inputFileOption) - .withOption(modelFileOption) - 
.create(); - - Parser parser = new Parser(); - parser.setHelpOption(help); - parser.setHelpTrigger("--help"); - parser.setGroup(normalArgs); - parser.setHelpFormatter(new HelpFormatter(" ", "", " ", 130)); - CommandLine cmdLine = parser.parseAndHelp(args); - - if (cmdLine == null) { - return false; - } - - inputFile = getStringArgument(cmdLine, inputFileOption); - modelFile = getStringArgument(cmdLine, modelFileOption); - showAuc = getBooleanArgument(cmdLine, auc); - showScores = getBooleanArgument(cmdLine, scores); - showConfusion = getBooleanArgument(cmdLine, confusion); - - return true; - } - - private static boolean getBooleanArgument(CommandLine cmdLine, Option option) { - return cmdLine.hasOption(option); - } - - private static String getStringArgument(CommandLine cmdLine, Option inputFile) { - return (String) cmdLine.getValue(inputFile); - } - -} http://git-wip-us.apache.org/repos/asf/mahout/blob/99a5358f/examples/src/main/java/org/apache/mahout/classifier/sgd/SGDHelper.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/mahout/classifier/sgd/SGDHelper.java b/examples/src/main/java/org/apache/mahout/classifier/sgd/SGDHelper.java deleted file mode 100644 index c657803..0000000 --- a/examples/src/main/java/org/apache/mahout/classifier/sgd/SGDHelper.java +++ /dev/null @@ -1,151 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.mahout.classifier.sgd;

import com.google.common.collect.Multiset;
import org.apache.mahout.classifier.NewsgroupHelper;
import org.apache.mahout.ep.State;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.function.DoubleFunction;
import org.apache.mahout.math.function.Functions;
import org.apache.mahout.vectorizer.encoders.Dictionary;

import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import java.util.TreeMap;

/**
 * Shared utility routines for the SGD newsgroup / email example programs:
 * model dissection, input-file shuffling, and periodic progress reporting.
 */
public final class SGDHelper {

  // Human-readable names for the three "leak" feature modes indexed by leakType % 3.
  private static final String[] LEAK_LABELS = {"none", "month-year", "day-month-year"};

  private SGDHelper() {
  }

  /**
   * Prints a feature-weight dissection of the best learner found so far.
   * Re-encodes a random sample of up to 500 input files with trace dictionaries
   * enabled so that ModelDissector can attribute weight to named features.
   *
   * @param leakType          which date-leak features to include when encoding (passed through)
   * @param dictionary        maps newsgroup names to category indexes
   * @param learningAlgorithm source of the best CrossFoldLearner to dissect
   * @param files             candidate input files; 500 are sampled after shuffling
   * @param overallCounts     running token counts, updated by the encoder
   * @throws IOException if an input file cannot be read
   */
  public static void dissect(int leakType,
                             Dictionary dictionary,
                             AdaptiveLogisticRegression learningAlgorithm,
                             Iterable<File> files, Multiset<String> overallCounts) throws IOException {
    CrossFoldLearner model = learningAlgorithm.getBest().getPayload().getLearner();
    // close() flushes any pending regularization so weights are final before dissection.
    model.close();

    Map<String, Set<Integer>> traceDictionary = new TreeMap<>();
    ModelDissector md = new ModelDissector();

    NewsgroupHelper helper = new NewsgroupHelper();
    helper.getEncoder().setTraceDictionary(traceDictionary);
    helper.getBias().setTraceDictionary(traceDictionary);

    // NOTE(review): subList(0, 500) assumes at least 500 input files — confirm callers guarantee this.
    for (File file : permute(files, helper.getRandom()).subList(0, 500)) {
      // Parent directory name is the newsgroup (i.e. the true category).
      String ng = file.getParentFile().getName();
      int actual = dictionary.intern(ng);

      // Fresh trace per document so md.update sees only this document's features.
      traceDictionary.clear();
      Vector v = helper.encodeFeatureVector(file, actual, leakType, overallCounts);
      md.update(v, traceDictionary, model);
    }

    List<String> ngNames = new ArrayList<>(dictionary.values());
    List<ModelDissector.Weight> weights = md.summary(100);
    System.out.println("============");
    System.out.println("Model Dissection");
    for (ModelDissector.Weight w : weights) {
      // NOTE(review): the +1 offset into ngNames looks suspicious — verify against
      // ModelDissector.Weight.getMaxImpact()'s indexing convention.
      System.out.printf("%s\t%.1f\t%s\t%.1f\t%s\t%.1f\t%s%n",
          w.getFeature(), w.getWeight(), ngNames.get(w.getMaxImpact() + 1),
          w.getCategory(1), w.getWeight(1), w.getCategory(2), w.getWeight(2));
    }
  }

  /**
   * Returns a new list containing the given files in uniformly random order
   * (inside-out Fisher-Yates shuffle built incrementally). The input is not modified.
   *
   * @param files the files to shuffle
   * @param rand  randomness source
   * @return a freshly allocated shuffled list
   */
  public static List<File> permute(Iterable<File> files, Random rand) {
    List<File> r = new ArrayList<>();
    for (File file : files) {
      int i = rand.nextInt(r.size() + 1);
      if (i == r.size()) {
        r.add(file);
      } else {
        // Move the displaced element to the end, put the new one at position i.
        r.add(r.get(i));
        r.set(i, file);
      }
    }
    return r;
  }

  /**
   * Periodically reports training progress and checkpoints the best model.
   * Output cadence follows a 1-2-5 x 10^k pattern driven by info.getStep():
   * a line is printed (and the step advanced by 0.25) only when k is a multiple
   * of bump * scale.
   *
   * @param info     mutable accumulator for step state and running averages
   * @param leakType selects the LEAK_LABELS entry printed with each report
   * @param k        number of training examples seen so far
   * @param best     best learner state so far; may be null early in training
   * @throws IOException if writing the checkpoint model file fails
   */
  static void analyzeState(SGDInfo info, int leakType, int k, State<AdaptiveLogisticRegression.Wrapper,
      CrossFoldLearner> best) throws IOException {
    int bump = info.getBumps()[(int) Math.floor(info.getStep()) % info.getBumps().length];
    int scale = (int) Math.pow(10, Math.floor(info.getStep() / info.getBumps().length));
    double maxBeta;
    double nonZeros;
    double positive;
    double norm;

    double lambda = 0;
    double mu = 0;

    if (best != null) {
      CrossFoldLearner state = best.getPayload().getLearner();
      info.setAverageCorrect(state.percentCorrect());
      info.setAverageLL(state.logLikelihood());

      OnlineLogisticRegression model = state.getModels().get(0);
      // finish off pending regularization
      model.close();

      Matrix beta = model.getBeta();
      // Largest absolute coefficient.
      maxBeta = beta.aggregate(Functions.MAX, Functions.ABS);
      // Count of coefficients with magnitude above a small tolerance.
      nonZeros = beta.aggregate(Functions.PLUS, new DoubleFunction() {
        @Override
        public double apply(double v) {
          return Math.abs(v) > 1.0e-6 ? 1 : 0;
        }
      });
      // Count of strictly positive coefficients.
      positive = beta.aggregate(Functions.PLUS, new DoubleFunction() {
        @Override
        public double apply(double v) {
          return v > 0 ? 1 : 0;
        }
      });
      // L1 norm of the coefficient matrix.
      norm = beta.aggregate(Functions.PLUS, Functions.ABS);

      lambda = best.getMappedParams()[0];
      mu = best.getMappedParams()[1];
    } else {
      maxBeta = 0;
      nonZeros = 0;
      positive = 0;
      norm = 0;
    }
    if (k % (bump * scale) == 0) {
      if (best != null) {
        // Checkpoint the current best model so training progress survives a crash.
        File modelFile = new File(System.getProperty("java.io.tmpdir"), "news-group-" + k + ".model");
        ModelSerializer.writeBinary(modelFile.getAbsolutePath(), best.getPayload().getLearner().getModels().get(0));
      }

      info.setStep(info.getStep() + 0.25);
      System.out.printf("%.2f\t%.2f\t%.2f\t%.2f\t%.8g\t%.8g\t", maxBeta, nonZeros, positive, norm, lambda, mu);
      System.out.printf("%d\t%.3f\t%.2f\t%s%n",
          k, info.getAverageLL(), info.getAverageCorrect() * 100, LEAK_LABELS[leakType % 3]);
    }
  }

}
/**
 * Mutable holder for the running statistics that the SGD example programs
 * report during training: the current average log-likelihood, the average
 * percent-correct, and the step/bump state that controls how often a
 * progress line is printed (a 1-2-5 x 10^k cadence).
 */
final class SGDInfo {

  /** Multipliers for the 1-2-5 reporting cadence. */
  private int[] bumps = {1, 2, 5};

  /** Fractional step counter advanced each time a report is emitted. */
  private double step;

  /** Running average log-likelihood of the best learner. */
  private double averageLL;

  /** Running average fraction of correct classifications. */
  private double averageCorrect;

  double getAverageLL() {
    return averageLL;
  }

  void setAverageLL(double averageLL) {
    this.averageLL = averageLL;
  }

  double getAverageCorrect() {
    return averageCorrect;
  }

  void setAverageCorrect(double averageCorrect) {
    this.averageCorrect = averageCorrect;
  }

  double getStep() {
    return step;
  }

  void setStep(double step) {
    this.step = step;
  }

  int[] getBumps() {
    return bumps;
  }

  void setBumps(int[] bumps) {
    this.bumps = bumps;
  }

}
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.mahout.classifier.sgd; - -import com.google.common.base.Joiner; -import com.google.common.base.Splitter; -import com.google.common.collect.Lists; -import com.google.common.io.Closeables; -import com.google.common.io.Files; -import org.apache.commons.io.Charsets; -import org.apache.mahout.common.RandomUtils; -import org.apache.mahout.math.DenseVector; -import org.apache.mahout.math.Vector; -import org.apache.mahout.math.list.IntArrayList; -import org.apache.mahout.math.stats.OnlineSummarizer; -import org.apache.mahout.vectorizer.encoders.ConstantValueEncoder; -import org.apache.mahout.vectorizer.encoders.FeatureVectorEncoder; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import java.io.BufferedReader; -import java.io.Closeable; -import java.io.File; -import java.io.FileInputStream; -import java.io.FileOutputStream; -import java.io.IOException; -import java.io.InputStream; -import java.io.OutputStreamWriter; -import java.io.PrintWriter; -import java.nio.ByteBuffer; -import java.util.ArrayList; -import java.util.List; -import java.util.Random; - -/** - * Shows how different encoding choices can make big speed differences. - * <p/> - * Run with command line options --generate 1000000 test.csv to generate a million data lines in - * test.csv. 
- * <p/> - * Run with command line options --parser test.csv to time how long it takes to parse and encode - * those million data points - * <p/> - * Run with command line options --fast test.csv to time how long it takes to parse and encode those - * million data points using byte-level parsing and direct value encoding. - * <p/> - * This doesn't demonstrate text encoding which is subject to somewhat different tricks. The basic - * idea of caching hash locations and byte level parsing still very much applies to text, however. - */ -public final class SimpleCsvExamples { - - public static final char SEPARATOR_CHAR = '\t'; - private static final int FIELDS = 100; - - private static final Logger log = LoggerFactory.getLogger(SimpleCsvExamples.class); - - private SimpleCsvExamples() {} - - public static void main(String[] args) throws IOException { - FeatureVectorEncoder[] encoder = new FeatureVectorEncoder[FIELDS]; - for (int i = 0; i < FIELDS; i++) { - encoder[i] = new ConstantValueEncoder("v" + 1); - } - - OnlineSummarizer[] s = new OnlineSummarizer[FIELDS]; - for (int i = 0; i < FIELDS; i++) { - s[i] = new OnlineSummarizer(); - } - long t0 = System.currentTimeMillis(); - Vector v = new DenseVector(1000); - if ("--generate".equals(args[0])) { - try (PrintWriter out = - new PrintWriter(new OutputStreamWriter(new FileOutputStream(new File(args[2])), Charsets.UTF_8))) { - int n = Integer.parseInt(args[1]); - for (int i = 0; i < n; i++) { - Line x = Line.generate(); - out.println(x); - } - } - } else if ("--parse".equals(args[0])) { - try (BufferedReader in = Files.newReader(new File(args[1]), Charsets.UTF_8)){ - String line = in.readLine(); - while (line != null) { - v.assign(0); - Line x = new Line(line); - for (int i = 0; i < FIELDS; i++) { - s[i].add(x.getDouble(i)); - encoder[i].addToVector(x.get(i), v); - } - line = in.readLine(); - } - } - String separator = ""; - for (int i = 0; i < FIELDS; i++) { - System.out.printf("%s%.3f", separator, s[i].getMean()); - 
separator = ","; - } - } else if ("--fast".equals(args[0])) { - try (FastLineReader in = new FastLineReader(new FileInputStream(args[1]))){ - FastLine line = in.read(); - while (line != null) { - v.assign(0); - for (int i = 0; i < FIELDS; i++) { - double z = line.getDouble(i); - s[i].add(z); - encoder[i].addToVector((byte[]) null, z, v); - } - line = in.read(); - } - } - - String separator = ""; - for (int i = 0; i < FIELDS; i++) { - System.out.printf("%s%.3f", separator, s[i].getMean()); - separator = ","; - } - } - System.out.printf("\nElapsed time = %.3f%n", (System.currentTimeMillis() - t0) / 1000.0); - } - - - private static final class Line { - private static final Splitter ON_TABS = Splitter.on(SEPARATOR_CHAR).trimResults(); - public static final Joiner WITH_COMMAS = Joiner.on(SEPARATOR_CHAR); - - public static final Random RAND = RandomUtils.getRandom(); - - private final List<String> data; - - private Line(CharSequence line) { - data = Lists.newArrayList(ON_TABS.split(line)); - } - - private Line() { - data = new ArrayList<>(); - } - - public double getDouble(int field) { - return Double.parseDouble(data.get(field)); - } - - /** - * Generate a random line with 20 fields each with integer values. - * - * @return A new line with data. - */ - public static Line generate() { - Line r = new Line(); - for (int i = 0; i < FIELDS; i++) { - double mean = ((i + 1) * 257) % 50 + 1; - r.data.add(Integer.toString(randomValue(mean))); - } - return r; - } - - /** - * Returns a random exponentially distributed integer with a particular mean value. This is - * just a way to create more small numbers than big numbers. 
- * - * @param mean mean of the distribution - * @return random exponentially distributed integer with the specific mean - */ - private static int randomValue(double mean) { - return (int) (-mean * Math.log1p(-RAND.nextDouble())); - } - - @Override - public String toString() { - return WITH_COMMAS.join(data); - } - - public String get(int field) { - return data.get(field); - } - } - - private static final class FastLine { - - private final ByteBuffer base; - private final IntArrayList start = new IntArrayList(); - private final IntArrayList length = new IntArrayList(); - - private FastLine(ByteBuffer base) { - this.base = base; - } - - public static FastLine read(ByteBuffer buf) { - FastLine r = new FastLine(buf); - r.start.add(buf.position()); - int offset = buf.position(); - while (offset < buf.limit()) { - int ch = buf.get(); - offset = buf.position(); - switch (ch) { - case '\n': - r.length.add(offset - r.start.get(r.length.size()) - 1); - return r; - case SEPARATOR_CHAR: - r.length.add(offset - r.start.get(r.length.size()) - 1); - r.start.add(offset); - break; - default: - // nothing to do for now - } - } - throw new IllegalArgumentException("Not enough bytes in buffer"); - } - - public double getDouble(int field) { - int offset = start.get(field); - int size = length.get(field); - switch (size) { - case 1: - return base.get(offset) - '0'; - case 2: - return (base.get(offset) - '0') * 10 + base.get(offset + 1) - '0'; - default: - double r = 0; - for (int i = 0; i < size; i++) { - r = 10 * r + base.get(offset + i) - '0'; - } - return r; - } - } - } - - private static final class FastLineReader implements Closeable { - private final InputStream in; - private final ByteBuffer buf = ByteBuffer.allocate(100000); - - private FastLineReader(InputStream in) throws IOException { - this.in = in; - buf.limit(0); - fillBuffer(); - } - - public FastLine read() throws IOException { - fillBuffer(); - if (buf.remaining() > 0) { - return FastLine.read(buf); - } else { - return 
null; - } - } - - private void fillBuffer() throws IOException { - if (buf.remaining() < 10000) { - buf.compact(); - int n = in.read(buf.array(), buf.position(), buf.remaining()); - if (n == -1) { - buf.flip(); - } else { - buf.limit(buf.position() + n); - buf.position(0); - } - } - } - - @Override - public void close() { - try { - Closeables.close(in, true); - } catch (IOException e) { - log.error(e.getMessage(), e); - } - } - } -} http://git-wip-us.apache.org/repos/asf/mahout/blob/99a5358f/examples/src/main/java/org/apache/mahout/classifier/sgd/TestASFEmail.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/mahout/classifier/sgd/TestASFEmail.java b/examples/src/main/java/org/apache/mahout/classifier/sgd/TestASFEmail.java deleted file mode 100644 index 074f774..0000000 --- a/examples/src/main/java/org/apache/mahout/classifier/sgd/TestASFEmail.java +++ /dev/null @@ -1,152 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.mahout.classifier.sgd; - -import org.apache.commons.cli2.CommandLine; -import org.apache.commons.cli2.Group; -import org.apache.commons.cli2.Option; -import org.apache.commons.cli2.builder.ArgumentBuilder; -import org.apache.commons.cli2.builder.DefaultOptionBuilder; -import org.apache.commons.cli2.builder.GroupBuilder; -import org.apache.commons.cli2.commandline.Parser; -import org.apache.commons.cli2.util.HelpFormatter; -import org.apache.commons.io.Charsets; -import org.apache.hadoop.conf.Configuration; -import org.apache.hadoop.fs.Path; -import org.apache.hadoop.fs.PathFilter; -import org.apache.hadoop.io.Text; -import org.apache.mahout.classifier.ClassifierResult; -import org.apache.mahout.classifier.ResultAnalyzer; -import org.apache.mahout.common.Pair; -import org.apache.mahout.common.iterator.sequencefile.PathType; -import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirIterator; -import org.apache.mahout.math.Vector; -import org.apache.mahout.math.VectorWritable; -import org.apache.mahout.vectorizer.encoders.Dictionary; - -import java.io.File; -import java.io.FileInputStream; -import java.io.IOException; -import java.io.OutputStreamWriter; -import java.io.PrintWriter; - -/** - * Run the ASF email, as trained by TrainASFEmail - */ -public final class TestASFEmail { - - private String inputFile; - private String modelFile; - - private TestASFEmail() {} - - public static void main(String[] args) throws IOException { - TestASFEmail runner = new TestASFEmail(); - if (runner.parseArgs(args)) { - runner.run(new PrintWriter(new OutputStreamWriter(System.out, Charsets.UTF_8), true)); - } - } - - public void run(PrintWriter output) throws IOException { - - File base = new File(inputFile); - //contains the best model - OnlineLogisticRegression classifier = - ModelSerializer.readBinary(new FileInputStream(modelFile), OnlineLogisticRegression.class); - - - Dictionary asfDictionary = new Dictionary(); - Configuration conf = new 
Configuration(); - PathFilter testFilter = new PathFilter() { - @Override - public boolean accept(Path path) { - return path.getName().contains("test"); - } - }; - SequenceFileDirIterator<Text, VectorWritable> iter = - new SequenceFileDirIterator<>(new Path(base.toString()), PathType.LIST, testFilter, - null, true, conf); - - long numItems = 0; - while (iter.hasNext()) { - Pair<Text, VectorWritable> next = iter.next(); - asfDictionary.intern(next.getFirst().toString()); - numItems++; - } - - System.out.println(numItems + " test files"); - ResultAnalyzer ra = new ResultAnalyzer(asfDictionary.values(), "DEFAULT"); - iter = new SequenceFileDirIterator<>(new Path(base.toString()), PathType.LIST, testFilter, - null, true, conf); - while (iter.hasNext()) { - Pair<Text, VectorWritable> next = iter.next(); - String ng = next.getFirst().toString(); - - int actual = asfDictionary.intern(ng); - Vector result = classifier.classifyFull(next.getSecond().get()); - int cat = result.maxValueIndex(); - double score = result.maxValue(); - double ll = classifier.logLikelihood(actual, next.getSecond().get()); - ClassifierResult cr = new ClassifierResult(asfDictionary.values().get(cat), score, ll); - ra.addInstance(asfDictionary.values().get(actual), cr); - - } - output.println(ra); - } - - boolean parseArgs(String[] args) { - DefaultOptionBuilder builder = new DefaultOptionBuilder(); - - Option help = builder.withLongName("help").withDescription("print this list").create(); - - ArgumentBuilder argumentBuilder = new ArgumentBuilder(); - Option inputFileOption = builder.withLongName("input") - .withRequired(true) - .withArgument(argumentBuilder.withName("input").withMaximum(1).create()) - .withDescription("where to get training data") - .create(); - - Option modelFileOption = builder.withLongName("model") - .withRequired(true) - .withArgument(argumentBuilder.withName("model").withMaximum(1).create()) - .withDescription("where to get a model") - .create(); - - Group normalArgs = new 
GroupBuilder() - .withOption(help) - .withOption(inputFileOption) - .withOption(modelFileOption) - .create(); - - Parser parser = new Parser(); - parser.setHelpOption(help); - parser.setHelpTrigger("--help"); - parser.setGroup(normalArgs); - parser.setHelpFormatter(new HelpFormatter(" ", "", " ", 130)); - CommandLine cmdLine = parser.parseAndHelp(args); - - if (cmdLine == null) { - return false; - } - - inputFile = (String) cmdLine.getValue(inputFileOption); - modelFile = (String) cmdLine.getValue(modelFileOption); - return true; - } - -} http://git-wip-us.apache.org/repos/asf/mahout/blob/99a5358f/examples/src/main/java/org/apache/mahout/classifier/sgd/TestNewsGroups.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/mahout/classifier/sgd/TestNewsGroups.java b/examples/src/main/java/org/apache/mahout/classifier/sgd/TestNewsGroups.java deleted file mode 100644 index f0316e9..0000000 --- a/examples/src/main/java/org/apache/mahout/classifier/sgd/TestNewsGroups.java +++ /dev/null @@ -1,141 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.mahout.classifier.sgd; - -import com.google.common.collect.HashMultiset; -import com.google.common.collect.Multiset; -import org.apache.commons.cli2.CommandLine; -import org.apache.commons.cli2.Group; -import org.apache.commons.cli2.Option; -import org.apache.commons.cli2.builder.ArgumentBuilder; -import org.apache.commons.cli2.builder.DefaultOptionBuilder; -import org.apache.commons.cli2.builder.GroupBuilder; -import org.apache.commons.cli2.commandline.Parser; -import org.apache.commons.cli2.util.HelpFormatter; -import org.apache.commons.io.Charsets; -import org.apache.mahout.classifier.ClassifierResult; -import org.apache.mahout.classifier.NewsgroupHelper; -import org.apache.mahout.classifier.ResultAnalyzer; -import org.apache.mahout.math.Vector; -import org.apache.mahout.vectorizer.encoders.Dictionary; - -import java.io.File; -import java.io.FileInputStream; -import java.io.IOException; -import java.io.OutputStreamWriter; -import java.io.PrintWriter; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.List; - -/** - * Run the 20 news groups test data through SGD, as trained by {@link org.apache.mahout.classifier.sgd.TrainNewsGroups}. 
- */ -public final class TestNewsGroups { - - private String inputFile; - private String modelFile; - - private TestNewsGroups() { - } - - public static void main(String[] args) throws IOException { - TestNewsGroups runner = new TestNewsGroups(); - if (runner.parseArgs(args)) { - runner.run(new PrintWriter(new OutputStreamWriter(System.out, Charsets.UTF_8), true)); - } - } - - public void run(PrintWriter output) throws IOException { - - File base = new File(inputFile); - //contains the best model - OnlineLogisticRegression classifier = - ModelSerializer.readBinary(new FileInputStream(modelFile), OnlineLogisticRegression.class); - - Dictionary newsGroups = new Dictionary(); - Multiset<String> overallCounts = HashMultiset.create(); - - List<File> files = new ArrayList<>(); - for (File newsgroup : base.listFiles()) { - if (newsgroup.isDirectory()) { - newsGroups.intern(newsgroup.getName()); - files.addAll(Arrays.asList(newsgroup.listFiles())); - } - } - System.out.println(files.size() + " test files"); - ResultAnalyzer ra = new ResultAnalyzer(newsGroups.values(), "DEFAULT"); - for (File file : files) { - String ng = file.getParentFile().getName(); - - int actual = newsGroups.intern(ng); - NewsgroupHelper helper = new NewsgroupHelper(); - //no leak type ensures this is a normal vector - Vector input = helper.encodeFeatureVector(file, actual, 0, overallCounts); - Vector result = classifier.classifyFull(input); - int cat = result.maxValueIndex(); - double score = result.maxValue(); - double ll = classifier.logLikelihood(actual, input); - ClassifierResult cr = new ClassifierResult(newsGroups.values().get(cat), score, ll); - ra.addInstance(newsGroups.values().get(actual), cr); - - } - output.println(ra); - } - - boolean parseArgs(String[] args) { - DefaultOptionBuilder builder = new DefaultOptionBuilder(); - - Option help = builder.withLongName("help").withDescription("print this list").create(); - - ArgumentBuilder argumentBuilder = new ArgumentBuilder(); - Option 
inputFileOption = builder.withLongName("input") - .withRequired(true) - .withArgument(argumentBuilder.withName("input").withMaximum(1).create()) - .withDescription("where to get training data") - .create(); - - Option modelFileOption = builder.withLongName("model") - .withRequired(true) - .withArgument(argumentBuilder.withName("model").withMaximum(1).create()) - .withDescription("where to get a model") - .create(); - - Group normalArgs = new GroupBuilder() - .withOption(help) - .withOption(inputFileOption) - .withOption(modelFileOption) - .create(); - - Parser parser = new Parser(); - parser.setHelpOption(help); - parser.setHelpTrigger("--help"); - parser.setGroup(normalArgs); - parser.setHelpFormatter(new HelpFormatter(" ", "", " ", 130)); - CommandLine cmdLine = parser.parseAndHelp(args); - - if (cmdLine == null) { - return false; - } - - inputFile = (String) cmdLine.getValue(inputFileOption); - modelFile = (String) cmdLine.getValue(modelFileOption); - return true; - } - -} http://git-wip-us.apache.org/repos/asf/mahout/blob/99a5358f/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainASFEmail.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainASFEmail.java b/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainASFEmail.java deleted file mode 100644 index e681f92..0000000 --- a/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainASFEmail.java +++ /dev/null @@ -1,137 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.mahout.classifier.sgd; - -import com.google.common.collect.HashMultiset; -import com.google.common.collect.Multiset; -import com.google.common.collect.Ordering; -import org.apache.hadoop.conf.Configuration; -import org.apache.hadoop.fs.Path; -import org.apache.hadoop.fs.PathFilter; -import org.apache.hadoop.io.Text; -import org.apache.mahout.common.AbstractJob; -import org.apache.mahout.common.Pair; -import org.apache.mahout.common.iterator.sequencefile.PathType; -import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirIterator; -import org.apache.mahout.ep.State; -import org.apache.mahout.math.VectorWritable; -import org.apache.mahout.vectorizer.encoders.Dictionary; - -import java.io.File; -import java.util.ArrayList; -import java.util.Collections; -import java.util.List; - -public final class TrainASFEmail extends AbstractJob { - - private TrainASFEmail() { - } - - @Override - public int run(String[] args) throws Exception { - addInputOption(); - addOutputOption(); - addOption("categories", "nc", "The number of categories to train on", true); - addOption("cardinality", "c", "The size of the vectors to use", "100000"); - addOption("threads", "t", "The number of threads to use in the learner", "20"); - addOption("poolSize", "p", "The number of CrossFoldLearners to use in the AdaptiveLogisticRegression. 
" - + "Higher values require more memory.", "5"); - if (parseArguments(args) == null) { - return -1; - } - - File base = new File(getInputPath().toString()); - - Multiset<String> overallCounts = HashMultiset.create(); - File output = new File(getOutputPath().toString()); - output.mkdirs(); - int numCats = Integer.parseInt(getOption("categories")); - int cardinality = Integer.parseInt(getOption("cardinality", "100000")); - int threadCount = Integer.parseInt(getOption("threads", "20")); - int poolSize = Integer.parseInt(getOption("poolSize", "5")); - Dictionary asfDictionary = new Dictionary(); - AdaptiveLogisticRegression learningAlgorithm = - new AdaptiveLogisticRegression(numCats, cardinality, new L1(), threadCount, poolSize); - learningAlgorithm.setInterval(800); - learningAlgorithm.setAveragingWindow(500); - - //We ran seq2encoded and split input already, so let's just build up the dictionary - Configuration conf = new Configuration(); - PathFilter trainFilter = new PathFilter() { - @Override - public boolean accept(Path path) { - return path.getName().contains("training"); - } - }; - SequenceFileDirIterator<Text, VectorWritable> iter = - new SequenceFileDirIterator<>(new Path(base.toString()), PathType.LIST, trainFilter, null, true, conf); - long numItems = 0; - while (iter.hasNext()) { - Pair<Text, VectorWritable> next = iter.next(); - asfDictionary.intern(next.getFirst().toString()); - numItems++; - } - - System.out.println(numItems + " training files"); - - SGDInfo info = new SGDInfo(); - - iter = new SequenceFileDirIterator<>(new Path(base.toString()), PathType.LIST, trainFilter, - null, true, conf); - int k = 0; - while (iter.hasNext()) { - Pair<Text, VectorWritable> next = iter.next(); - String ng = next.getFirst().toString(); - int actual = asfDictionary.intern(ng); - //we already have encoded - learningAlgorithm.train(actual, next.getSecond().get()); - k++; - State<AdaptiveLogisticRegression.Wrapper, CrossFoldLearner> best = learningAlgorithm.getBest(); 
- - SGDHelper.analyzeState(info, 0, k, best); - } - learningAlgorithm.close(); - //TODO: how to dissection since we aren't processing the files here - //SGDHelper.dissect(leakType, asfDictionary, learningAlgorithm, files, overallCounts); - System.out.println("exiting main, writing model to " + output); - - ModelSerializer.writeBinary(output + "/asf.model", - learningAlgorithm.getBest().getPayload().getLearner().getModels().get(0)); - - List<Integer> counts = new ArrayList<>(); - System.out.println("Word counts"); - for (String count : overallCounts.elementSet()) { - counts.add(overallCounts.count(count)); - } - Collections.sort(counts, Ordering.natural().reverse()); - k = 0; - for (Integer count : counts) { - System.out.println(k + "\t" + count); - k++; - if (k > 1000) { - break; - } - } - return 0; - } - - public static void main(String[] args) throws Exception { - TrainASFEmail trainer = new TrainASFEmail(); - trainer.run(args); - } -} http://git-wip-us.apache.org/repos/asf/mahout/blob/99a5358f/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainAdaptiveLogistic.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainAdaptiveLogistic.java b/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainAdaptiveLogistic.java deleted file mode 100644 index defb5b9..0000000 --- a/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainAdaptiveLogistic.java +++ /dev/null @@ -1,377 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.mahout.classifier.sgd; - -import com.google.common.io.Resources; -import org.apache.commons.cli2.CommandLine; -import org.apache.commons.cli2.Group; -import org.apache.commons.cli2.Option; -import org.apache.commons.cli2.builder.ArgumentBuilder; -import org.apache.commons.cli2.builder.DefaultOptionBuilder; -import org.apache.commons.cli2.builder.GroupBuilder; -import org.apache.commons.cli2.commandline.Parser; -import org.apache.commons.cli2.util.HelpFormatter; -import org.apache.commons.io.Charsets; -import org.apache.mahout.classifier.sgd.AdaptiveLogisticRegression.Wrapper; -import org.apache.mahout.ep.State; -import org.apache.mahout.math.RandomAccessSparseVector; -import org.apache.mahout.math.Vector; - -import java.io.BufferedReader; -import java.io.File; -import java.io.FileInputStream; -import java.io.FileOutputStream; -import java.io.IOException; -import java.io.InputStream; -import java.io.InputStreamReader; -import java.io.OutputStream; -import java.io.OutputStreamWriter; -import java.io.PrintWriter; -import java.util.ArrayList; -import java.util.List; -import java.util.Locale; - -public final class TrainAdaptiveLogistic { - - private static String inputFile; - private static String outputFile; - private static AdaptiveLogisticModelParameters lmp; - private static int passes; - private static boolean showperf; - private static int skipperfnum = 99; - private static AdaptiveLogisticRegression model; - - private TrainAdaptiveLogistic() { - } - - public static void main(String[] args) throws Exception { - 
mainToOutput(args, new PrintWriter(new OutputStreamWriter(System.out, Charsets.UTF_8), true)); - } - - static void mainToOutput(String[] args, PrintWriter output) throws Exception { - if (parseArgs(args)) { - - CsvRecordFactory csv = lmp.getCsvRecordFactory(); - model = lmp.createAdaptiveLogisticRegression(); - State<Wrapper, CrossFoldLearner> best; - CrossFoldLearner learner = null; - - int k = 0; - for (int pass = 0; pass < passes; pass++) { - BufferedReader in = open(inputFile); - - // read variable names - csv.firstLine(in.readLine()); - - String line = in.readLine(); - while (line != null) { - // for each new line, get target and predictors - Vector input = new RandomAccessSparseVector(lmp.getNumFeatures()); - int targetValue = csv.processLine(line, input); - - // update model - model.train(targetValue, input); - k++; - - if (showperf && (k % (skipperfnum + 1) == 0)) { - - best = model.getBest(); - if (best != null) { - learner = best.getPayload().getLearner(); - } - if (learner != null) { - double averageCorrect = learner.percentCorrect(); - double averageLL = learner.logLikelihood(); - output.printf("%d\t%.3f\t%.2f%n", - k, averageLL, averageCorrect * 100); - } else { - output.printf(Locale.ENGLISH, - "%10d %2d %s%n", k, targetValue, - "AdaptiveLogisticRegression has not found a good model ......"); - } - } - line = in.readLine(); - } - in.close(); - } - - best = model.getBest(); - if (best != null) { - learner = best.getPayload().getLearner(); - } - if (learner == null) { - output.println("AdaptiveLogisticRegression has failed to train a model."); - return; - } - - try (OutputStream modelOutput = new FileOutputStream(outputFile)) { - lmp.saveTo(modelOutput); - } - - OnlineLogisticRegression lr = learner.getModels().get(0); - output.println(lmp.getNumFeatures()); - output.println(lmp.getTargetVariable() + " ~ "); - String sep = ""; - for (String v : csv.getTraceDictionary().keySet()) { - double weight = predictorWeight(lr, 0, csv, v); - if (weight != 0) { - 
output.printf(Locale.ENGLISH, "%s%.3f*%s", sep, weight, v); - sep = " + "; - } - } - output.printf("%n"); - - for (int row = 0; row < lr.getBeta().numRows(); row++) { - for (String key : csv.getTraceDictionary().keySet()) { - double weight = predictorWeight(lr, row, csv, key); - if (weight != 0) { - output.printf(Locale.ENGLISH, "%20s %.5f%n", key, weight); - } - } - for (int column = 0; column < lr.getBeta().numCols(); column++) { - output.printf(Locale.ENGLISH, "%15.9f ", lr.getBeta().get(row, column)); - } - output.println(); - } - } - - } - - private static double predictorWeight(OnlineLogisticRegression lr, int row, RecordFactory csv, String predictor) { - double weight = 0; - for (Integer column : csv.getTraceDictionary().get(predictor)) { - weight += lr.getBeta().get(row, column); - } - return weight; - } - - private static boolean parseArgs(String[] args) { - DefaultOptionBuilder builder = new DefaultOptionBuilder(); - - Option help = builder.withLongName("help") - .withDescription("print this list").create(); - - Option quiet = builder.withLongName("quiet") - .withDescription("be extra quiet").create(); - - - ArgumentBuilder argumentBuilder = new ArgumentBuilder(); - Option showperf = builder - .withLongName("showperf") - .withDescription("output performance measures during training") - .create(); - - Option inputFile = builder - .withLongName("input") - .withRequired(true) - .withArgument( - argumentBuilder.withName("input").withMaximum(1) - .create()) - .withDescription("where to get training data").create(); - - Option outputFile = builder - .withLongName("output") - .withRequired(true) - .withArgument( - argumentBuilder.withName("output").withMaximum(1) - .create()) - .withDescription("where to write the model content").create(); - - Option threads = builder.withLongName("threads") - .withArgument( - argumentBuilder.withName("threads").withDefault("4").create()) - .withDescription("the number of threads AdaptiveLogisticRegression uses") - .create(); - 
- - Option predictors = builder.withLongName("predictors") - .withRequired(true) - .withArgument(argumentBuilder.withName("predictors").create()) - .withDescription("a list of predictor variables").create(); - - Option types = builder - .withLongName("types") - .withRequired(true) - .withArgument(argumentBuilder.withName("types").create()) - .withDescription( - "a list of predictor variable types (numeric, word, or text)") - .create(); - - Option target = builder - .withLongName("target") - .withDescription("the name of the target variable") - .withRequired(true) - .withArgument( - argumentBuilder.withName("target").withMaximum(1) - .create()) - .create(); - - Option targetCategories = builder - .withLongName("categories") - .withDescription("the number of target categories to be considered") - .withRequired(true) - .withArgument(argumentBuilder.withName("categories").withMaximum(1).create()) - .create(); - - - Option features = builder - .withLongName("features") - .withDescription("the number of internal hashed features to use") - .withArgument( - argumentBuilder.withName("numFeatures") - .withDefault("1000").withMaximum(1).create()) - .create(); - - Option passes = builder - .withLongName("passes") - .withDescription("the number of times to pass over the input data") - .withArgument( - argumentBuilder.withName("passes").withDefault("2") - .withMaximum(1).create()) - .create(); - - Option interval = builder.withLongName("interval") - .withArgument( - argumentBuilder.withName("interval").withDefault("500").create()) - .withDescription("the interval property of AdaptiveLogisticRegression") - .create(); - - Option window = builder.withLongName("window") - .withArgument( - argumentBuilder.withName("window").withDefault("800").create()) - .withDescription("the average propery of AdaptiveLogisticRegression") - .create(); - - Option skipperfnum = builder.withLongName("skipperfnum") - .withArgument( - argumentBuilder.withName("skipperfnum").withDefault("99").create()) - 
.withDescription("show performance measures every (skipperfnum + 1) rows") - .create(); - - Option prior = builder.withLongName("prior") - .withArgument( - argumentBuilder.withName("prior").withDefault("L1").create()) - .withDescription("the prior algorithm to use: L1, L2, ebp, tp, up") - .create(); - - Option priorOption = builder.withLongName("prioroption") - .withArgument( - argumentBuilder.withName("prioroption").create()) - .withDescription("constructor parameter for ElasticBandPrior and TPrior") - .create(); - - Option auc = builder.withLongName("auc") - .withArgument( - argumentBuilder.withName("auc").withDefault("global").create()) - .withDescription("the auc to use: global or grouped") - .create(); - - - - Group normalArgs = new GroupBuilder().withOption(help) - .withOption(quiet).withOption(inputFile).withOption(outputFile) - .withOption(target).withOption(targetCategories) - .withOption(predictors).withOption(types).withOption(passes) - .withOption(interval).withOption(window).withOption(threads) - .withOption(prior).withOption(features).withOption(showperf) - .withOption(skipperfnum).withOption(priorOption).withOption(auc) - .create(); - - Parser parser = new Parser(); - parser.setHelpOption(help); - parser.setHelpTrigger("--help"); - parser.setGroup(normalArgs); - parser.setHelpFormatter(new HelpFormatter(" ", "", " ", 130)); - CommandLine cmdLine = parser.parseAndHelp(args); - - if (cmdLine == null) { - return false; - } - - TrainAdaptiveLogistic.inputFile = getStringArgument(cmdLine, inputFile); - TrainAdaptiveLogistic.outputFile = getStringArgument(cmdLine, - outputFile); - - List<String> typeList = new ArrayList<>(); - for (Object x : cmdLine.getValues(types)) { - typeList.add(x.toString()); - } - - List<String> predictorList = new ArrayList<>(); - for (Object x : cmdLine.getValues(predictors)) { - predictorList.add(x.toString()); - } - - lmp = new AdaptiveLogisticModelParameters(); - lmp.setTargetVariable(getStringArgument(cmdLine, target)); - 
lmp.setMaxTargetCategories(getIntegerArgument(cmdLine, targetCategories)); - lmp.setNumFeatures(getIntegerArgument(cmdLine, features)); - lmp.setInterval(getIntegerArgument(cmdLine, interval)); - lmp.setAverageWindow(getIntegerArgument(cmdLine, window)); - lmp.setThreads(getIntegerArgument(cmdLine, threads)); - lmp.setAuc(getStringArgument(cmdLine, auc)); - lmp.setPrior(getStringArgument(cmdLine, prior)); - if (cmdLine.getValue(priorOption) != null) { - lmp.setPriorOption(getDoubleArgument(cmdLine, priorOption)); - } - lmp.setTypeMap(predictorList, typeList); - TrainAdaptiveLogistic.showperf = getBooleanArgument(cmdLine, showperf); - TrainAdaptiveLogistic.skipperfnum = getIntegerArgument(cmdLine, skipperfnum); - TrainAdaptiveLogistic.passes = getIntegerArgument(cmdLine, passes); - - lmp.checkParameters(); - - return true; - } - - private static String getStringArgument(CommandLine cmdLine, - Option inputFile) { - return (String) cmdLine.getValue(inputFile); - } - - private static boolean getBooleanArgument(CommandLine cmdLine, Option option) { - return cmdLine.hasOption(option); - } - - private static int getIntegerArgument(CommandLine cmdLine, Option features) { - return Integer.parseInt((String) cmdLine.getValue(features)); - } - - private static double getDoubleArgument(CommandLine cmdLine, Option op) { - return Double.parseDouble((String) cmdLine.getValue(op)); - } - - public static AdaptiveLogisticRegression getModel() { - return model; - } - - public static LogisticModelParameters getParameters() { - return lmp; - } - - static BufferedReader open(String inputFile) throws IOException { - InputStream in; - try { - in = Resources.getResource(inputFile).openStream(); - } catch (IllegalArgumentException e) { - in = new FileInputStream(new File(inputFile)); - } - return new BufferedReader(new InputStreamReader(in, Charsets.UTF_8)); - } - -}
