Repository: incubator-samoa Updated Branches: refs/heads/master 4471fe4ae -> dbc3aab13
SAMOA-68: Saving true and predicted labels to file Fix #61 Project: http://git-wip-us.apache.org/repos/asf/incubator-samoa/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-samoa/commit/dbc3aab1 Tree: http://git-wip-us.apache.org/repos/asf/incubator-samoa/tree/dbc3aab1 Diff: http://git-wip-us.apache.org/repos/asf/incubator-samoa/diff/dbc3aab1 Branch: refs/heads/master Commit: dbc3aab13d200ca0212518543db3f9f2f93e1751 Parents: 4471fe4 Author: Maciej Grzenda <[email protected]> Authored: Wed May 17 13:51:56 2017 +0200 Committer: Gianmarco De Francisci Morales <[email protected]> Committed: Wed Jul 5 10:50:29 2017 +0300 ---------------------------------------------------------------------- .../org/apache/samoa/apex/AlgosTestApex.java | 2 + ...BasicClassificationPerformanceEvaluator.java | 100 +++++++++++++++---- .../BasicRegressionPerformanceEvaluator.java | 46 +++++++-- .../samoa/evaluation/EvaluatorCVProcessor.java | 27 ++--- .../samoa/evaluation/EvaluatorProcessor.java | 72 ++++++++++++- .../F1ClassificationPerformanceEvaluator.java | 42 +++++++- .../samoa/evaluation/PerformanceEvaluator.java | 10 +- ...indowClassificationPerformanceEvaluator.java | 43 +++++++- .../java/org/apache/samoa/moa/core/Vote.java | 86 ++++++++++++++++ .../samoa/moa/evaluation/LearningCurve.java | 62 ++++++++++++ .../samoa/tasks/PrequentialEvaluation.java | 15 ++- .../test/java/org/apache/samoa/AlgosTest.java | 5 + .../test/java/org/apache/samoa/AlgosTest.java | 31 +++--- .../test/java/org/apache/samoa/TestParams.java | 50 ++++++++-- .../test/java/org/apache/samoa/TestUtils.java | 64 ++++++++---- .../test/java/org/apache/samoa/AlgosTest.java | 2 + 16 files changed, 569 insertions(+), 88 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/dbc3aab1/samoa-apex/src/test/java/org/apache/samoa/apex/AlgosTestApex.java ---------------------------------------------------------------------- diff --git a/samoa-apex/src/test/java/org/apache/samoa/apex/AlgosTestApex.java b/samoa-apex/src/test/java/org/apache/samoa/apex/AlgosTestApex.java index 80d9449..7e0ca48 100644 --- a/samoa-apex/src/test/java/org/apache/samoa/apex/AlgosTestApex.java +++ b/samoa-apex/src/test/java/org/apache/samoa/apex/AlgosTestApex.java @@ -35,6 +35,7 @@ public class AlgosTestApex { .samplingSize(20_000) .evaluationInstances(200_000) .classifiedInstances(200_000) + .labelSamplingSize(10l) .classificationsCorrect(55f) .kappaStat(0f) .kappaTempStat(0f) @@ -54,6 +55,7 @@ public class AlgosTestApex { .samplingSize(20_000) .evaluationInstances(180_000) .classifiedInstances(190_000) + .labelSamplingSize(10l) .classificationsCorrect(60f) .kappaStat(0f) .kappaTempStat(0f) http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/dbc3aab1/samoa-api/src/main/java/org/apache/samoa/evaluation/BasicClassificationPerformanceEvaluator.java ---------------------------------------------------------------------- diff --git a/samoa-api/src/main/java/org/apache/samoa/evaluation/BasicClassificationPerformanceEvaluator.java b/samoa-api/src/main/java/org/apache/samoa/evaluation/BasicClassificationPerformanceEvaluator.java index 24abe3e..a77831a 100644 --- a/samoa-api/src/main/java/org/apache/samoa/evaluation/BasicClassificationPerformanceEvaluator.java +++ b/samoa-api/src/main/java/org/apache/samoa/evaluation/BasicClassificationPerformanceEvaluator.java @@ -1,5 +1,10 @@ package org.apache.samoa.evaluation; +import java.util.Arrays; +import java.util.List; + +import org.apache.samoa.instances.Attribute; + /* * #%L * SAMOA @@ -24,6 +29,7 @@ import org.apache.samoa.instances.Instance; import org.apache.samoa.instances.Utils; import org.apache.samoa.moa.AbstractMOAObject; import org.apache.samoa.moa.core.Measurement; +import org.apache.samoa.moa.core.Vote; /** * Classification evaluator that performs basic incremental evaluation. @@ -32,11 +38,23 @@ import org.apache.samoa.moa.core.Measurement; * @author Albert Bifet (abifet at cs dot waikato dot ac dot nz) * @version $Revision: 7 $ */ -public class BasicClassificationPerformanceEvaluator extends AbstractMOAObject implements - ClassificationPerformanceEvaluator { +public class BasicClassificationPerformanceEvaluator extends AbstractMOAObject + implements ClassificationPerformanceEvaluator { private static final long serialVersionUID = 1L; + // the number of decimal places placed for double values in prediction file + // the value of 10 is used since some votes can be relatively small + public static final int DECIMAL_PLACES = 10; + + // the vote value to be used when a classifier made no vote for the class at + // all + public static final int NO_VOTE_FOR_CLASS = 0; + + // recent vote objects i.e. predicted, true classes and votes for individual + // classes + protected Vote[] votes; + protected double weightObserved; protected double weightCorrect; @@ -49,11 +67,17 @@ public class BasicClassificationPerformanceEvaluator extends AbstractMOAObject i private double weightCorrectNoChangeClassifier; + protected double[] classVotes; + private int lastSeenClass; + private String instanceIdentifier; + + private Instance lastSeenInstance; @Override public void reset() { reset(this.numClasses); + votes = null; } public void reset(int numClasses) { @@ -68,10 +92,11 @@ public class BasicClassificationPerformanceEvaluator extends AbstractMOAObject i this.weightCorrect = 0.0; this.weightCorrectNoChangeClassifier = 0.0; this.lastSeenClass = 0; + votes = null; } @Override - public void addResult(Instance inst, double[] classVotes) { + public void addResult(Instance inst, double[] classVotes, String instanceIdentifier) { double weight = inst.weight(); int trueClass = (int) inst.classValue(); if (weight > 0.0) { @@ -94,20 +119,60 @@ public class BasicClassificationPerformanceEvaluator extends AbstractMOAObject i this.weightCorrectNoChangeClassifier += weight; } this.lastSeenClass = trueClass; + this.lastSeenInstance = inst; + this.instanceIdentifier = instanceIdentifier; + this.classVotes = classVotes; } @Override public Measurement[] getPerformanceMeasurements() { - return new Measurement[] { - new Measurement("classified instances", - getTotalWeightObserved()), - new Measurement("classifications correct (percent)", - getFractionCorrectlyClassified() * 100.0), - new Measurement("Kappa Statistic (percent)", - getKappaStatistic() * 100.0), - new Measurement("Kappa Temporal Statistic (percent)", - getKappaTemporalStatistic() * 100.0) - }; + return new Measurement[] { new Measurement("classified instances", getTotalWeightObserved()), + new Measurement("classifications correct (percent)", getFractionCorrectlyClassified() * 100.0), + new Measurement("Kappa Statistic (percent)", getKappaStatistic() * 100.0), + new Measurement("Kappa Temporal Statistic (percent)", getKappaTemporalStatistic() * 100.0) }; + + } + + /** + * This method is used to retrieve predictions and votes (for classification only) + * + * @return String This returns an array of predictions and votes objects. + */ + @Override + public Vote[] getPredictionVotes() { + Attribute classAttribute = this.lastSeenInstance.dataset().classAttribute(); + double trueValue = this.lastSeenInstance.classValue(); + List<String> classAttributeValues = classAttribute.getAttributeValues(); + + int trueNominalIndex = (int) trueValue; + String trueNominalValue = classAttributeValues.get(trueNominalIndex); + + // initialise votes first time they are supposed to be used + if (votes == null) { + this.votes = new Vote[classAttributeValues.size() + 3]; + votes[0] = new Vote("instance number"); + votes[1] = new Vote("true class value"); + votes[2] = new Vote("predicted class value"); + + // create as many objects as the number of classes + for (int i = 0; i < classAttributeValues.size(); i++) { + votes[3 + i] = new Vote("votes_" + classAttributeValues.get(i)); + } + } + + // use/(re-use existing) vote objects + votes[0].setValue(this.instanceIdentifier); + votes[1].setValue(trueNominalValue); + votes[2].setValue(classAttributeValues.get(Utils.maxIndex(classVotes))); + for (int i = 0; i < classAttributeValues.size(); i++) { + if (i < classVotes.length) { + votes[3 + i].setValue(classVotes[i], this.DECIMAL_PLACES); + } else { + votes[3 + i].setValue(this.NO_VOTE_FOR_CLASS, 0); + } + } + + return votes; } @@ -116,8 +181,7 @@ public class BasicClassificationPerformanceEvaluator extends AbstractMOAObject i } public double getFractionCorrectlyClassified() { - return this.weightObserved > 0.0 ? this.weightCorrect - / this.weightObserved : 0.0; + return this.weightObserved > 0.0 ? this.weightCorrect / this.weightObserved : 0.0; } public double getFractionIncorrectlyClassified() { @@ -129,8 +193,7 @@ public class BasicClassificationPerformanceEvaluator extends AbstractMOAObject i double p0 = getFractionCorrectlyClassified(); double pc = 0.0; for (int i = 0; i < this.numClasses; i++) { - pc += (this.rowKappa[i] / this.weightObserved) - * (this.columnKappa[i] / this.weightObserved); + pc += (this.rowKappa[i] / this.weightObserved) * (this.columnKappa[i] / this.weightObserved); } return (p0 - pc) / (1.0 - pc); } else { @@ -151,7 +214,6 @@ public class BasicClassificationPerformanceEvaluator extends AbstractMOAObject i @Override public void getDescription(StringBuilder sb, int indent) { - Measurement.getMeasurementsDescription(getPerformanceMeasurements(), - sb, indent); + Measurement.getMeasurementsDescription(getPerformanceMeasurements(), sb, indent); } } http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/dbc3aab1/samoa-api/src/main/java/org/apache/samoa/evaluation/BasicRegressionPerformanceEvaluator.java ---------------------------------------------------------------------- diff --git a/samoa-api/src/main/java/org/apache/samoa/evaluation/BasicRegressionPerformanceEvaluator.java b/samoa-api/src/main/java/org/apache/samoa/evaluation/BasicRegressionPerformanceEvaluator.java index ec48156..ab16904 100644 --- a/samoa-api/src/main/java/org/apache/samoa/evaluation/BasicRegressionPerformanceEvaluator.java +++ b/samoa-api/src/main/java/org/apache/samoa/evaluation/BasicRegressionPerformanceEvaluator.java @@ -1,5 +1,9 @@ package org.apache.samoa.evaluation; +import java.util.List; + +import org.apache.samoa.instances.Attribute; + /* * #%L * SAMOA @@ -21,8 +25,10 @@ package org.apache.samoa.evaluation; */ import org.apache.samoa.instances.Instance; +import org.apache.samoa.instances.Utils; import org.apache.samoa.moa.AbstractMOAObject; import org.apache.samoa.moa.core.Measurement; +import org.apache.samoa.moa.core.Vote; /** * Regression evaluator that performs basic incremental evaluation. @@ -35,6 +41,10 @@ public class BasicRegressionPerformanceEvaluator extends AbstractMOAObject private static final long serialVersionUID = 1L; + // the number of decimal places placed for double values in prediction file + // the value of 10 is used since some predicted values can be relatively small + public static final int DECIMAL_PLACES = 10; + protected double weightObserved; protected double squareError; @@ -47,6 +57,10 @@ public class BasicRegressionPerformanceEvaluator extends AbstractMOAObject protected double averageTargetError; + private String instanceIdentifier; + private Instance lastSeenInstance; + private double lastPredictedValue; + @Override public void reset() { this.weightObserved = 0.0; @@ -59,19 +73,23 @@ public class BasicRegressionPerformanceEvaluator extends AbstractMOAObject } @Override - public void addResult(Instance inst, double[] prediction) { + public void addResult(Instance inst, double[] prediction, String instanceIdentifier) { double weight = inst.weight(); double classValue = inst.classValue(); if (weight > 0.0) { if (prediction.length > 0) { - double meanTarget = this.weightObserved != 0 ? - this.sumTarget / this.weightObserved : 0.0; + double meanTarget = this.weightObserved != 0 ? this.sumTarget / this.weightObserved : 0.0; this.squareError += (classValue - prediction[0]) * (classValue - prediction[0]); this.averageError += Math.abs(classValue - prediction[0]); this.squareTargetError += (classValue - meanTarget) * (classValue - meanTarget); this.averageTargetError += Math.abs(classValue - meanTarget); this.sumTarget += classValue; this.weightObserved += weight; + this.lastPredictedValue = prediction[0]; + this.lastSeenInstance = inst; + this.instanceIdentifier = instanceIdentifier; + } else { + this.lastPredictedValue = Double.NaN; } } } @@ -92,6 +110,22 @@ public class BasicRegressionPerformanceEvaluator extends AbstractMOAObject }; } + /** + * This method is used to retrieve predictions + * + * @return String This returns an array of predictions and votes objects. + */ + @Override + public Vote[] getPredictionVotes() { + double trueValue = this.lastSeenInstance.classValue(); + return new Vote[] { + new Vote("instance number", + this.instanceIdentifier), + new Vote("true value", trueValue, this.DECIMAL_PLACES), + new Vote("predicted value", this.lastPredictedValue, this.DECIMAL_PLACES) + }; + } + public double getTotalWeightObserved() { return this.weightObserved; } @@ -123,12 +157,10 @@ public class BasicRegressionPerformanceEvaluator extends AbstractMOAObject } private double getRelativeMeanError() { - return this.averageTargetError > 0 ? - this.averageError / this.averageTargetError : 0.0; + return this.averageTargetError > 0 ? this.averageError / this.averageTargetError : 0.0; } private double getRelativeSquareError() { - return Math.sqrt(this.squareTargetError > 0 ? - this.squareError / this.squareTargetError : 0.0); + return Math.sqrt(this.squareTargetError > 0 ? this.squareError / this.squareTargetError : 0.0); } } http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/dbc3aab1/samoa-api/src/main/java/org/apache/samoa/evaluation/EvaluatorCVProcessor.java ---------------------------------------------------------------------- diff --git a/samoa-api/src/main/java/org/apache/samoa/evaluation/EvaluatorCVProcessor.java b/samoa-api/src/main/java/org/apache/samoa/evaluation/EvaluatorCVProcessor.java index f282f0d..05d0a27 100644 --- a/samoa-api/src/main/java/org/apache/samoa/evaluation/EvaluatorCVProcessor.java +++ b/samoa-api/src/main/java/org/apache/samoa/evaluation/EvaluatorCVProcessor.java @@ -39,12 +39,11 @@ import java.util.concurrent.TimeUnit; public class EvaluatorCVProcessor implements Processor { /** - * - */ + * + */ private static final long serialVersionUID = -2778051819116753612L; - private static final Logger logger = - LoggerFactory.getLogger(EvaluatorCVProcessor.class); + private static final Logger logger = LoggerFactory.getLogger(EvaluatorCVProcessor.class); private static final String ORDERING_MEASUREMENT_NAME = "evaluation instances"; @@ -90,7 +89,8 @@ public class EvaluatorCVProcessor implements Processor { addStatisticsForInstanceReceived(instanceIndex, result.getEvaluationIndex(), 1); - evaluators[result.getEvaluationIndex()].addResult(result.getInstance(), result.getClassVotes()); + evaluators[result.getEvaluationIndex()].addResult(result.getInstance(), result.getClassVotes(), + String.valueOf(instanceIndex)); if (hasAllVotesArrivedInstance(instanceIndex)) { totalCount += 1; @@ -110,8 +110,6 @@ public class EvaluatorCVProcessor implements Processor { } } - - return false; } @@ -122,6 +120,7 @@ public class EvaluatorCVProcessor implements Processor { int count = map.get(instanceIndex); return (count == this.foldNumber); } + protected void addStatisticsForInstanceReceived(int instanceIndex, int evaluationIndex, int add) { if (this.mapCountsforInstanceReceived == null) { this.mapCountsforInstanceReceived = new HashMap<>(); @@ -190,10 +189,10 @@ public class EvaluatorCVProcessor implements Processor { private void addMeasurement() { List<Measurement> measurements = new Vector<>(); - measurements.add(new Measurement(ORDERING_MEASUREMENT_NAME, totalCount )); + measurements.add(new Measurement(ORDERING_MEASUREMENT_NAME, totalCount)); Measurement[] finalMeasurements = getEvaluationMeasurements( - measurements.toArray(new Measurement[measurements.size()]), evaluators); + measurements.toArray(new Measurement[measurements.size()]), evaluators); LearningEvaluation learningEvaluation = new LearningEvaluation(finalMeasurements); learningCurve.insertEntry(learningEvaluation); @@ -220,7 +219,7 @@ public class EvaluatorCVProcessor implements Processor { long experimentEnd = System.nanoTime(); long totalExperimentTime = TimeUnit.SECONDS.convert(experimentEnd - experimentStart, TimeUnit.NANOSECONDS); - logger.info("total evaluation time: {} seconds for {} instances", totalExperimentTime, totalCount ); + logger.info("total evaluation time: {} seconds for {} instances", totalExperimentTime, totalCount); if (immediateResultStream != null) { immediateResultStream.println("# COMPLETED"); @@ -257,7 +256,7 @@ public class EvaluatorCVProcessor implements Processor { return this; } - public Builder foldNumber(int foldNumber){ + public Builder foldNumber(int foldNumber) { this.foldNumber = foldNumber; return this; } @@ -267,7 +266,8 @@ public class EvaluatorCVProcessor implements Processor { } } - public Measurement[] getEvaluationMeasurements(Measurement[] modelMeasurements, PerformanceEvaluator[] subEvaluators) { + public Measurement[] getEvaluationMeasurements(Measurement[] modelMeasurements, + PerformanceEvaluator[] subEvaluators) { List<Measurement> measurementList = new LinkedList<Measurement>(); if (modelMeasurements != null) { measurementList.addAll(Arrays.asList(modelMeasurements)); @@ -280,7 +280,8 @@ public class EvaluatorCVProcessor implements Processor { subMeasurements.add(subEvaluator.getPerformanceMeasurements()); } } - Measurement[] avgMeasurements = Measurement.averageMeasurements(subMeasurements.toArray(new Measurement[subMeasurements.size()][])); + Measurement[] avgMeasurements = Measurement + .averageMeasurements(subMeasurements.toArray(new Measurement[subMeasurements.size()][])); measurementList.addAll(Arrays.asList(avgMeasurements)); } return measurementList.toArray(new Measurement[measurementList.size()]); http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/dbc3aab1/samoa-api/src/main/java/org/apache/samoa/evaluation/EvaluatorProcessor.java ---------------------------------------------------------------------- diff --git a/samoa-api/src/main/java/org/apache/samoa/evaluation/EvaluatorProcessor.java b/samoa-api/src/main/java/org/apache/samoa/evaluation/EvaluatorProcessor.java index 6ec50dc..e78395a 100644 --- a/samoa-api/src/main/java/org/apache/samoa/evaluation/EvaluatorProcessor.java +++ b/samoa-api/src/main/java/org/apache/samoa/evaluation/EvaluatorProcessor.java @@ -33,6 +33,7 @@ import org.apache.samoa.core.ContentEvent; import org.apache.samoa.core.Processor; import org.apache.samoa.learners.ResultContentEvent; import org.apache.samoa.moa.core.Measurement; +import org.apache.samoa.moa.core.Vote; import org.apache.samoa.moa.evaluation.LearningCurve; import org.apache.samoa.moa.evaluation.LearningEvaluation; import org.slf4j.Logger; @@ -41,20 +42,23 @@ import org.slf4j.LoggerFactory; public class EvaluatorProcessor implements Processor { /** - * - */ + * + */ private static final long serialVersionUID = -2778051819116753612L; - private static final Logger logger = - LoggerFactory.getLogger(EvaluatorProcessor.class); + private static final Logger logger = LoggerFactory.getLogger(EvaluatorProcessor.class); private static final String ORDERING_MEASUREMENT_NAME = "evaluation instances"; private final PerformanceEvaluator evaluator; private final int samplingFrequency; private final File dumpFile; + private final File predictionFile; + private final int labelSamplingFrequency; private transient PrintStream immediateResultStream = null; + private transient PrintStream immediatePredictionStream = null; private transient boolean firstDump = true; + private transient boolean firstVoteDump = true; private long totalCount = 0; private long experimentStart = 0; @@ -68,6 +72,8 @@ public class EvaluatorProcessor implements Processor { this.evaluator = builder.evaluator; this.samplingFrequency = builder.samplingFrequency; this.dumpFile = builder.dumpFile; + this.predictionFile = builder.predictionFile; + this.labelSamplingFrequency = builder.labelSamplingFrequency; } @Override @@ -84,12 +90,18 @@ public class EvaluatorProcessor implements Processor { this.addMeasurement(); } + //adding a vote - true class value, predicted class value and for classification - votes + if ((immediatePredictionStream != null) && (totalCount > 0) && (totalCount % labelSamplingFrequency) == 0) { + this.addVote(); + } + if (result.isLastEvent()) { this.concludeMeasurement(); return true; } - evaluator.addResult(result.getInstance(), result.getClassVotes()); + String instanceIndex = String.valueOf(result.getInstanceIndex()); + evaluator.addResult(result.getInstance(), result.getClassVotes(), instanceIndex); totalCount += 1; if (totalCount == 1) { @@ -125,7 +137,20 @@ public class EvaluatorProcessor implements Processor { } } + if (this.predictionFile != null) { + try { + this.immediatePredictionStream = new PrintStream(new FileOutputStream(predictionFile), true); + } catch (FileNotFoundException e) { + this.immediatePredictionStream = null; + logger.error("File not found exception for {}:{}", this.predictionFile.getAbsolutePath(), e.toString()); + } catch (Exception e) { + this.immediatePredictionStream = null; + logger.error("Exception when creating {}:{}", this.predictionFile.getAbsolutePath(), e.toString()); + } + } + this.firstDump = true; + this.firstVoteDump = true; } @Override @@ -179,6 +204,26 @@ public class EvaluatorProcessor implements Processor { } } + /** + * This method is used to create one line of a text file containing predictions and votes (for classification only). + * In case, this is the first line a header line is also added + */ + private void addVote() { + Vote[] finalVotes = evaluator.getPredictionVotes(); + learningCurve.setVote(finalVotes); + logger.debug("evaluator id = {}", this.id); + + if (immediatePredictionStream != null) { + if (firstVoteDump) { + immediatePredictionStream.println(learningCurve.voteHeaderToString()); + firstVoteDump = false; + } + + immediatePredictionStream.println(learningCurve.voteEntryToString()); + immediatePredictionStream.flush(); + } + } + private void concludeMeasurement() { logger.info("last event is received!"); logger.info("total count: {}", this.totalCount); @@ -192,6 +237,9 @@ public class EvaluatorProcessor implements Processor { if (immediateResultStream != null) { immediateResultStream.println("# COMPLETED"); + // + immediateResultStream + .println("# Total evaluation time: " + totalExperimentTime + " seconds for " + totalCount + " instances"); immediateResultStream.flush(); } // logger.info("average throughput rate: {} instances/seconds", @@ -203,6 +251,8 @@ public class EvaluatorProcessor implements Processor { private final PerformanceEvaluator evaluator; private int samplingFrequency = 100000; private File dumpFile = null; + private File predictionFile = null; + private int labelSamplingFrequency = 1; public Builder(PerformanceEvaluator evaluator) { this.evaluator = evaluator; @@ -212,6 +262,8 @@ public class EvaluatorProcessor implements Processor { this.evaluator = oldProcessor.evaluator; this.samplingFrequency = oldProcessor.samplingFrequency; this.dumpFile = oldProcessor.dumpFile; + this.predictionFile = oldProcessor.predictionFile; + this.labelSamplingFrequency = oldProcessor.labelSamplingFrequency; } public Builder samplingFrequency(int samplingFrequency) { @@ -224,6 +276,16 @@ public class EvaluatorProcessor implements Processor { return this; } + public Builder predictionFile(File file) { + this.predictionFile = file; + return this; + } + + public Builder labelSamplingFrequency(int samplingFrequency) { + this.labelSamplingFrequency = samplingFrequency; + return this; + } + public EvaluatorProcessor build() { return new EvaluatorProcessor(this); } http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/dbc3aab1/samoa-api/src/main/java/org/apache/samoa/evaluation/F1ClassificationPerformanceEvaluator.java ---------------------------------------------------------------------- diff --git a/samoa-api/src/main/java/org/apache/samoa/evaluation/F1ClassificationPerformanceEvaluator.java b/samoa-api/src/main/java/org/apache/samoa/evaluation/F1ClassificationPerformanceEvaluator.java index 89e74be..d54296d 100644 --- a/samoa-api/src/main/java/org/apache/samoa/evaluation/F1ClassificationPerformanceEvaluator.java +++ b/samoa-api/src/main/java/org/apache/samoa/evaluation/F1ClassificationPerformanceEvaluator.java @@ -1,5 +1,7 @@ package org.apache.samoa.evaluation; +import org.apache.samoa.instances.Attribute; + /* * #%L * SAMOA @@ -25,6 +27,7 @@ import org.apache.samoa.instances.Instance; import org.apache.samoa.instances.Utils; import org.apache.samoa.moa.AbstractMOAObject; import org.apache.samoa.moa.core.Measurement; +import org.apache.samoa.moa.core.Vote; import java.util.Collections; import java.util.List; @@ -44,7 +47,10 @@ public class F1ClassificationPerformanceEvaluator extends AbstractMOAObject impl protected long[] falsePos; protected long[] trueNeg; protected long[] falseNeg; - + private String instanceIdentifier; + private Instance lastSeenInstance; + protected double[] classVotes; + @Override public void reset() { reset(this.numClasses); @@ -67,7 +73,7 @@ public class F1ClassificationPerformanceEvaluator extends AbstractMOAObject impl } @Override - public void addResult(Instance inst, double[] classVotes) { + public void addResult(Instance inst, double[] classVotes, String instanceIndex) { if (numClasses==-1) reset(inst.numClasses()); int trueClass = (int) inst.classValue(); this.support[trueClass] += 1; @@ -95,6 +101,38 @@ public class F1ClassificationPerformanceEvaluator extends AbstractMOAObject impl Collections.addAll(measurements, getF1Measurements()); return measurements.toArray(new Measurement[measurements.size()]); } + + /** + * This method is used to retrieve predictions and votes (for classification only) + * + * @return String This returns an array of predictions and votes objects. + */ + @Override + public Vote[] getPredictionVotes() { + Attribute classAttribute = this.lastSeenInstance.dataset().classAttribute(); + double trueValue = this.lastSeenInstance.classValue(); + List<String> classAttributeValues = classAttribute.getAttributeValues(); + + int trueNominalIndex = (int) trueValue; + String trueNominalValue = classAttributeValues.get(trueNominalIndex); + + Vote[] votes = new Vote[classVotes.length + 3]; + votes[0] = new Vote("instance number", + this.instanceIdentifier); + votes[1] = new Vote("true class value", + trueNominalValue); + votes[2] = new Vote("predicted class value", + classAttributeValues.get(Utils.maxIndex(classVotes))); + + for (int i = 0; i < classAttributeValues.size(); i++) { + if (i < classVotes.length) { + votes[2 + i] = new Vote("votes_" + classAttributeValues.get(i), classVotes[i]); + } else { + votes[2 + i] = new Vote("votes_" + classAttributeValues.get(i), 0); + } + } + return votes; + } private Measurement[] getSupportMeasurements() { Measurement[] measurements = new Measurement[this.numClasses]; http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/dbc3aab1/samoa-api/src/main/java/org/apache/samoa/evaluation/PerformanceEvaluator.java ---------------------------------------------------------------------- diff --git a/samoa-api/src/main/java/org/apache/samoa/evaluation/PerformanceEvaluator.java b/samoa-api/src/main/java/org/apache/samoa/evaluation/PerformanceEvaluator.java index 0bd2450..c4c4a0b 100644 --- a/samoa-api/src/main/java/org/apache/samoa/evaluation/PerformanceEvaluator.java +++ b/samoa-api/src/main/java/org/apache/samoa/evaluation/PerformanceEvaluator.java @@ -23,6 +23,7 @@ package org.apache.samoa.evaluation; import org.apache.samoa.instances.Instance; import org.apache.samoa.moa.MOAObject; import org.apache.samoa.moa.core.Measurement; +import org.apache.samoa.moa.core.Vote; /** * Interface implemented by learner evaluators to monitor the results of the learning process. @@ -47,7 +48,7 @@ public interface PerformanceEvaluator extends MOAObject { * an array containing the estimated membership probabilities of the test instance in each class * @return an array of measurements monitored in this evaluator */ - public void addResult(Instance inst, double[] classVotes); + public void addResult(Instance inst, double[] classVotes, String instanceIdentifier); /** * Gets the current measurements monitored by this evaluator. @@ -55,4 +56,11 @@ public interface PerformanceEvaluator extends MOAObject { * @return an array of measurements monitored by this evaluator */ public Measurement[] getPerformanceMeasurements(); + + /** + * Gets the current votes monitored by this evaluator. + * + * @return an array of votes monitored by this evaluator + */ + public Vote[] getPredictionVotes(); } http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/dbc3aab1/samoa-api/src/main/java/org/apache/samoa/evaluation/WindowClassificationPerformanceEvaluator.java ---------------------------------------------------------------------- diff --git a/samoa-api/src/main/java/org/apache/samoa/evaluation/WindowClassificationPerformanceEvaluator.java b/samoa-api/src/main/java/org/apache/samoa/evaluation/WindowClassificationPerformanceEvaluator.java index c428a7f..6ea40ed 100644 --- a/samoa-api/src/main/java/org/apache/samoa/evaluation/WindowClassificationPerformanceEvaluator.java +++ b/samoa-api/src/main/java/org/apache/samoa/evaluation/WindowClassificationPerformanceEvaluator.java @@ -1,5 +1,9 @@ package org.apache.samoa.evaluation; +import java.util.List; + +import org.apache.samoa.instances.Attribute; + /* * #%L * SAMOA @@ -24,6 +28,7 @@ import org.apache.samoa.instances.Instance; import org.apache.samoa.instances.Utils; import org.apache.samoa.moa.AbstractMOAObject; import org.apache.samoa.moa.core.Measurement; +import org.apache.samoa.moa.core.Vote; import com.github.javacliparser.IntOption; @@ -59,6 +64,10 @@ public class WindowClassificationPerformanceEvaluator extends AbstractMOAObject protected int numClasses; + private String instanceIdentifier; + private Instance lastSeenInstance; + protected double[] classVotes; + public class Estimator { protected double[] window; @@ -127,7 +136,7 @@ public class WindowClassificationPerformanceEvaluator extends AbstractMOAObject } @Override - public void addResult(Instance inst, double[] classVotes) { + public void addResult(Instance inst, double[] classVotes, String instanceIndex) { double weight = inst.weight(); int trueClass = (int) inst.classValue(); if (weight > 0.0) { @@ -172,6 +181,38 @@ public class WindowClassificationPerformanceEvaluator extends AbstractMOAObject } + /** + * This method is used to retrieve predictions and votes (for classification only) + * + * @return String This returns an array of predictions and votes objects. + */ + @Override + public Vote[] getPredictionVotes() { + Attribute classAttribute = this.lastSeenInstance.dataset().classAttribute(); + double trueValue = this.lastSeenInstance.classValue(); + List<String> classAttributeValues = classAttribute.getAttributeValues(); + + int trueNominalIndex = (int) trueValue; + String trueNominalValue = classAttributeValues.get(trueNominalIndex); + + Vote[] votes = new Vote[classVotes.length + 3]; + votes[0] = new Vote("instance number", + this.instanceIdentifier); + votes[1] = new Vote("true class value", + trueNominalValue); + votes[2] = new Vote("predicted class value", + classAttributeValues.get(Utils.maxIndex(classVotes))); + + for (int i = 0; i < classAttributeValues.size(); i++) { + if (i < classVotes.length) { + votes[2 + i] = new Vote("votes_" + classAttributeValues.get(i), classVotes[i]); + } else { + votes[2 + i] = new Vote("votes_" + classAttributeValues.get(i), 0); + } + } + return votes; + } + public double getTotalWeightObserved() { return this.weightObserved.total(); } http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/dbc3aab1/samoa-api/src/main/java/org/apache/samoa/moa/core/Vote.java ---------------------------------------------------------------------- diff --git a/samoa-api/src/main/java/org/apache/samoa/moa/core/Vote.java b/samoa-api/src/main/java/org/apache/samoa/moa/core/Vote.java new file mode 100644 index 0000000..24ea3f3 --- /dev/null +++ b/samoa-api/src/main/java/org/apache/samoa/moa/core/Vote.java @@ -0,0 +1,86 @@ +package org.apache.samoa.moa.core; + +import java.io.Serializable; + +/* + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, + * either express or implied. See the License for the specific + * language governing permissions and limitations under the + * License. + */ + +import java.util.ArrayList; +import java.util.List; +import java.util.Locale; + +/** + * Class for storing votes. + * + */ +public class Vote implements Serializable { + + private static final long serialVersionUID = 1L; + + protected String name; + protected String value; + + public Vote(String name) { + this.name = name; + } + + public Vote(String name, String value) { + this.name = name; + this.value = value; + } + + public Vote(String name, double value) { + this(name, value, 3); + } + + public Vote(String name, double value, int fractionDigits) { + this(name); + setValue(value, fractionDigits); + } + + public String getName() { + return this.name; + } + + public String getValue() { + return this.value; + } + + public void setValue(String value) { + this.value = value; + } + + public void setValue(double value, int fractionDigits) { + // rely on dot as a decimal separator not to confuse CSV parsers + this.value = String.format(Locale.US, "%." + String.valueOf(fractionDigits) + "f", value); + } + + public static void getVotesDescription(Vote[] votes, + StringBuilder out, int indent) { + if (votes.length > 0) { + StringUtils.appendIndented(out, indent, votes[0].toString()); + for (int i = 1; i < votes.length; i++) { + StringUtils.appendNewlineIndented(out, indent, votes[i].toString()); + } + } + } + + public void getDescription(StringBuilder sb, int indent) { + sb.append(getName()); + sb.append(" = "); + sb.append(this.value); + } +} http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/dbc3aab1/samoa-api/src/main/java/org/apache/samoa/moa/evaluation/LearningCurve.java ---------------------------------------------------------------------- diff --git a/samoa-api/src/main/java/org/apache/samoa/moa/evaluation/LearningCurve.java b/samoa-api/src/main/java/org/apache/samoa/moa/evaluation/LearningCurve.java index 427e01d..dcc4f50 100644 --- a/samoa-api/src/main/java/org/apache/samoa/moa/evaluation/LearningCurve.java +++ b/samoa-api/src/main/java/org/apache/samoa/moa/evaluation/LearningCurve.java @@ -27,6 +27,7 @@ import org.apache.samoa.moa.AbstractMOAObject; import org.apache.samoa.moa.core.DoubleVector; import org.apache.samoa.moa.core.Measurement; import org.apache.samoa.moa.core.StringUtils; +import org.apache.samoa.moa.core.Vote; /** * Class that stores and keeps the history of evaluation measurements. @@ -40,8 +41,12 @@ public class LearningCurve extends AbstractMOAObject { protected List<String> measurementNames = new ArrayList<String>(); + protected List<String> voteNames = new ArrayList<String>(); + protected List<double[]> measurementValues = new ArrayList<double[]>(); + protected List<String> voteValues = new ArrayList<String>(); + public LearningCurve(String orderingMeasurementName) { this.measurementNames.add(orderingMeasurementName); } @@ -129,4 +134,61 @@ public class LearningCurve extends AbstractMOAObject { public String getMeasurementName(int measurementIndex) { return this.measurementNames.get(measurementIndex); } + + protected int addVoteName(String name) { + int index = this.voteNames.indexOf(name); + if (index < 0) { + index = this.voteNames.size(); + this.voteNames.add(name); + } + return index; + } + + public void setVote(Vote[] votes) { + this.voteValues.clear(); + for (Vote vote : votes) { + voteValues.add(addVoteName(vote.getName()), vote.getValue()); + } + } + + /** + * This method is used to set generate header line of a text file containing predictions and votes (for classification + * only) + * + * @return String This returns the text of the header of a file containing predictions and votes. + */ + public String voteHeaderToString() { + StringBuilder sb = new StringBuilder(); + boolean first = true; + for (String name : this.voteNames) { + if (!first) { + sb.append(','); + } else { + first = false; + } + sb.append(name); + } + return sb.toString(); + } + + /** + * This method is used to set generate one body line of a text file containing predictions and votes (for + * classification only) + * + * @return String This returns the text of one line of a file containing predictions and votes. + */ + public String voteEntryToString() { + StringBuilder sb = new StringBuilder(); + for (int i = 0; i < this.voteNames.size(); i++) { + if (i > 0) { + sb.append(','); + } + if ((i >= voteValues.size())) { + sb.append('?'); + } else { + sb.append(voteValues.get(i)); + } + } + return sb.toString(); + } } http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/dbc3aab1/samoa-api/src/main/java/org/apache/samoa/tasks/PrequentialEvaluation.java ---------------------------------------------------------------------- diff --git a/samoa-api/src/main/java/org/apache/samoa/tasks/PrequentialEvaluation.java b/samoa-api/src/main/java/org/apache/samoa/tasks/PrequentialEvaluation.java index 001622b..dab505b 100644 --- a/samoa-api/src/main/java/org/apache/samoa/tasks/PrequentialEvaluation.java +++ b/samoa-api/src/main/java/org/apache/samoa/tasks/PrequentialEvaluation.java @@ -84,6 +84,12 @@ public class PrequentialEvaluation implements Task, Configurable { "How many instances between samples of the learning performance.", 100000, 0, Integer.MAX_VALUE); + // The frequency of saving model output e.g. predicted class and votes made for individual classes to a file + // The name of the actual file to which model output will be saved is defined through resultFileOption + public IntOption labelSampleFrequencyOption = new IntOption("labelSampleFrequency", 'h', + "How many instances between samples of predicted labels and votes.", 1, + 0, Integer.MAX_VALUE); + public StringOption evaluationNameOption = new StringOption("evaluationName", 'n', "Identifier of the evaluation", "Prequential_" + new SimpleDateFormat("yyyyMMddHHmmss").format(new Date())); @@ -91,6 +97,11 @@ public class PrequentialEvaluation implements Task, Configurable { public FileOption dumpFileOption = new FileOption("dumpFile", 'd', "File to append intermediate csv results to", null, "csv", true); + // The name of the CSV file in which model output (and in the case of classification also votes for individual classes) + // will be saved + public FileOption resultFileOption = new FileOption("resultFile", 'g', "File to append intermediate model output to", + null, "csv", true); + // Default=0: no delay/waiting public IntOption sourceDelayOption = new IntOption("sourceDelay", 'w', "How many microseconds between injections of two instances.", 0, 0, Integer.MAX_VALUE); @@ -167,7 +178,9 @@ public class PrequentialEvaluation implements Task, Configurable { evaluatorOptionValue = getDefaultPerformanceEvaluatorForLearner(classifier); } evaluator = new EvaluatorProcessor.Builder(evaluatorOptionValue) - .samplingFrequency(sampleFrequencyOption.getValue()).dumpFile(dumpFileOption.getFile()).build(); + .samplingFrequency(sampleFrequencyOption.getValue()).dumpFile(dumpFileOption.getFile()) + .predictionFile(resultFileOption.getFile()).labelSamplingFrequency(labelSampleFrequencyOption.getValue()) + .build(); // evaluatorPi = builder.createPi(evaluator); // evaluatorPi.connectInputShuffleStream(evaluatorPiInputStream); http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/dbc3aab1/samoa-local/src/test/java/org/apache/samoa/AlgosTest.java ---------------------------------------------------------------------- diff --git a/samoa-local/src/test/java/org/apache/samoa/AlgosTest.java b/samoa-local/src/test/java/org/apache/samoa/AlgosTest.java index 52331c5..f621aba 100644 --- a/samoa-local/src/test/java/org/apache/samoa/AlgosTest.java +++ b/samoa-local/src/test/java/org/apache/samoa/AlgosTest.java @@ -22,6 +22,7 @@ package org.apache.samoa; import org.apache.samoa.LocalDoTask; import org.junit.Test; +import org.apache.samoa.TestParams; public class AlgosTest { @@ -32,6 +33,7 @@ public class AlgosTest { .samplingSize(20_000) .evaluationInstances(200_000) .classifiedInstances(200_000) + .labelSamplingSize(10l) .classificationsCorrect(75f) .kappaStat(0f) .kappaTempStat(0f) @@ -50,6 +52,7 @@ public class AlgosTest { .samplingSize(20_000) .evaluationInstances(200_000) .classifiedInstances(200_000) + .labelSamplingSize(1l) .classificationsCorrect(60f) .kappaStat(0f) .kappaTempStat(0f) @@ -68,6 +71,7 @@ public class AlgosTest { .samplingSize(20_000) .evaluationInstances(200_000) .classifiedInstances(200_000) + .labelSamplingSize(10l) .classificationsCorrect(65f) .kappaStat(0f) .kappaTempStat(0f) @@ -93,6 +97,7 @@ public class AlgosTest { .resultFilePollTimeout(10) .prePollWait(10) .taskClassName(LocalDoTask.class.getName()) + .labelFileCreated(false) .build(); TestUtils.test(vhtConfig); } http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/dbc3aab1/samoa-storm/src/test/java/org/apache/samoa/AlgosTest.java ---------------------------------------------------------------------- diff --git a/samoa-storm/src/test/java/org/apache/samoa/AlgosTest.java b/samoa-storm/src/test/java/org/apache/samoa/AlgosTest.java index d874e51..1c18eaf 100644 --- a/samoa-storm/src/test/java/org/apache/samoa/AlgosTest.java +++ b/samoa-storm/src/test/java/org/apache/samoa/AlgosTest.java @@ -35,9 +35,10 @@ public class AlgosTest { .samplingSize(20_000) .evaluationInstances(200_000) .classifiedInstances(200_000) + .labelSamplingSize(10l) .classificationsCorrect(55f) - .kappaStat(0f) - .kappaTempStat(0f) + .kappaStat(-0.1f) + .kappaTempStat(-0.1f) .cliStringTemplate(TestParams.Templates.PREQEVAL_VHT_RANDOMTREE) .resultFilePollTimeout(30) .prePollWait(15) @@ -54,6 +55,7 @@ public class AlgosTest { .samplingSize(20_000) .evaluationInstances(180_000) .classifiedInstances(190_000) + .labelSamplingSize(10l) .classificationsCorrect(60f) .kappaStat(0f) .kappaTempStat(0f) @@ -70,18 +72,19 @@ public class AlgosTest { public void testCVPReqVHTWithStorm() throws Exception { TestParams vhtConfig = new TestParams.Builder() - .inputInstances(200_000) - .samplingSize(20_000) - .evaluationInstances(200_000) - .classifiedInstances(200_000) - .classificationsCorrect(55f) - .kappaStat(0f) - .kappaTempStat(0f) - .cliStringTemplate(TestParams.Templates.PREQCVEVAL_VHT_RANDOMTREE) - .resultFilePollTimeout(30) - .prePollWait(15) - .taskClassName(LocalStormDoTask.class.getName()) - .build(); + .inputInstances(200_000) + .samplingSize(20_000) + .evaluationInstances(200_000) + .classifiedInstances(200_000) + .classificationsCorrect(55f) + .kappaStat(0f) + .kappaTempStat(0f) + .cliStringTemplate(TestParams.Templates.PREQCVEVAL_VHT_RANDOMTREE) + .resultFilePollTimeout(30) + .prePollWait(15) + .taskClassName(LocalStormDoTask.class.getName()) + .labelFileCreated(false) + .build(); TestUtils.test(vhtConfig); } http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/dbc3aab1/samoa-test/src/test/java/org/apache/samoa/TestParams.java ---------------------------------------------------------------------- diff --git a/samoa-test/src/test/java/org/apache/samoa/TestParams.java b/samoa-test/src/test/java/org/apache/samoa/TestParams.java index b066959..eb7e123 100644 --- a/samoa-test/src/test/java/org/apache/samoa/TestParams.java +++ b/samoa-test/src/test/java/org/apache/samoa/TestParams.java @@ -1,5 +1,7 @@ package org.apache.samoa; +import org.apache.samoa.TestParams.Builder; + /* * #%L * SAMOA @@ -32,20 +34,19 @@ public class TestParams { * </ul> * as well as the maximum number of instances for testing/training (-i) and the sampling size (-f) */ - public static class Templates { - - public final static String PREQEVAL_VHT_RANDOMTREE = "PrequentialEvaluation -d %s -i %d -f %d -w %d " + public static class Templates { + public final static String PREQEVAL_VHT_RANDOMTREE = "PrequentialEvaluation -d %s -i %d -f %d -w %d -g %s -h %d " + "-l (org.apache.samoa.learners.classifiers.trees.VerticalHoeffdingTree -p 4) " + "-s (org.apache.samoa.streams.generators.RandomTreeGenerator -c 2 -o 10 -u 10)"; - public final static String PREQEVAL_NAIVEBAYES_HYPERPLANE = "PrequentialEvaluation -d %s -i %d -f %d -w %d " + public final static String PREQEVAL_NAIVEBAYES_HYPERPLANE = "PrequentialEvaluation -d %s -i %d -f %d -w %d -g %s -h %d " + "-l (classifiers.SingleClassifier -l org.apache.samoa.learners.classifiers.NaiveBayes) " + "-s (org.apache.samoa.streams.generators.HyperplaneGenerator -c 2)"; // setting the number of nominal attributes to zero significantly reduces // the processing time, // so that it's acceptable in a test case - public final static String PREQEVAL_BAGGING_RANDOMTREE = "PrequentialEvaluation -d %s -i %d -f %d -w %d " + public final static String PREQEVAL_BAGGING_RANDOMTREE = "PrequentialEvaluation -d %s -i %d -f %d -w %d -g %s -h %d " + "-l (org.apache.samoa.learners.classifiers.ensemble.Bagging) " + "-s (org.apache.samoa.streams.generators.RandomTreeGenerator -c 2 -o 0 -u 10)"; @@ -60,6 +61,11 @@ public class TestParams { public static final String CLASSIFICATIONS_CORRECT = "classifications correct (percent)"; public static final String KAPPA_STAT = "Kappa Statistic (percent)"; public static final String KAPPA_TEMP_STAT = "Kappa Temporal Statistic (percent)"; + + public static final String INSTANCE_ID = "instance number"; + public static final String TRUE_CLASS_VALUE = "true class value"; + public static final String PREDICTED_CLASS_VALUE = "predicted class value"; + public static final String VOTES = "votes"; private long inputInstances; private long samplingSize; @@ -73,6 +79,8 @@ public class TestParams { private final int prePollWait; private int inputDelayMicroSec; private String taskClassName; + private boolean labelFileCreated; + private long labelSamplingSize; private TestParams(String taskClassName, long inputInstances, @@ -85,7 +93,9 @@ public class TestParams { String cliStringTemplate, int pollTimeoutSeconds, int prePollWait, - int inputDelayMicroSec) { + int inputDelayMicroSec, + boolean labelFileCreated, + long labelSamplingSize) { this.taskClassName = taskClassName; this.inputInstances = inputInstances; this.samplingSize = samplingSize; @@ -98,6 +108,12 @@ public class TestParams { this.pollTimeoutSeconds = pollTimeoutSeconds; this.prePollWait = prePollWait; this.inputDelayMicroSec = inputDelayMicroSec; + this.labelFileCreated = labelFileCreated; + this.labelSamplingSize = labelSamplingSize; + } + + public boolean getLabelFileCreated() { + return labelFileCreated; } public String getTaskClassName() { @@ -147,6 +163,10 @@ public class TestParams { public int getInputDelayMicroSec() { return inputDelayMicroSec; } + + public long getLabelSamplingSize() { + return labelSamplingSize; + } @Override public String toString() { @@ -163,6 +183,8 @@ public class TestParams { "prePollWait=" + prePollWait + "\n" + "taskClassName='" + taskClassName + '\'' + "\n" + "inputDelayMicroSec=" + inputDelayMicroSec + "\n" + + "labelFileCreated=" + labelFileCreated + "\n" + + "labelSamplingSize=" + labelSamplingSize + "\n" + '}'; } @@ -179,6 +201,8 @@ public class TestParams { private int prePollWaitSeconds = 10; private String taskClassName; private int inputDelayMicroSec = 0; + private boolean labelFileCreated = true; + private long labelSamplingSize = 0l; public Builder taskClassName(String taskClassName) { this.taskClassName = taskClassName; @@ -239,6 +263,16 @@ public class TestParams { this.prePollWaitSeconds = prePollWaitSeconds; return this; } + + public Builder labelFileCreated(boolean labelFileCreated) { + this.labelFileCreated = labelFileCreated; + return this; + } + + public Builder labelSamplingSize(long labelSamplingSize) { + this.labelSamplingSize = labelSamplingSize; + return this; + } public TestParams build() { return new TestParams(taskClassName, @@ -252,7 +286,9 @@ public class TestParams { cliStringTemplate, pollTimeoutSeconds, prePollWaitSeconds, - inputDelayMicroSec); + inputDelayMicroSec, + labelFileCreated, + labelSamplingSize); } } } http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/dbc3aab1/samoa-test/src/test/java/org/apache/samoa/TestUtils.java ---------------------------------------------------------------------- diff --git a/samoa-test/src/test/java/org/apache/samoa/TestUtils.java b/samoa-test/src/test/java/org/apache/samoa/TestUtils.java index 331f900..b5fef17 100644 --- a/samoa-test/src/test/java/org/apache/samoa/TestUtils.java +++ b/samoa-test/src/test/java/org/apache/samoa/TestUtils.java @@ -49,29 +49,29 @@ public class TestUtils { NoSuchMethodException, InvocationTargetException, IllegalAccessException, InterruptedException { final File tempFile = File.createTempFile("test", "test"); - - LOG.info("Starting test, output file is {}, test config is \n{}", tempFile.getAbsolutePath(), testParams.toString()); - + final File labelFile = File.createTempFile("result", "result"); + LOG.info("Starting test, output file is {}, test config is \n{}", tempFile.getAbsolutePath(), testParams.toString()); Executors.newSingleThreadExecutor().submit(new Callable<Void>() { - @Override public Void call() throws Exception { - try { + try { Class.forName(testParams.getTaskClassName()) - .getMethod("main", String[].class) - .invoke(null, (Object) String.format( - testParams.getCliStringTemplate(), - tempFile.getAbsolutePath(), - testParams.getInputInstances(), - testParams.getSamplingSize(), - testParams.getInputDelayMicroSec() - ).split("[ ]")); - } catch (Exception e) { - LOG.error("Cannot execute test {} {}", e.getMessage(), e.getCause().getMessage()); + .getMethod("main", String[].class) + .invoke(null, (Object) String.format( + testParams.getCliStringTemplate(), + tempFile.getAbsolutePath(), + testParams.getInputInstances(), + testParams.getSamplingSize(), + testParams.getInputDelayMicroSec(), + labelFile.getAbsolutePath(), + testParams.getLabelSamplingSize() + ).split("[ ]")); + } catch (Exception e) { + LOG.error("Cannot execute test {} {}", e.getMessage(), e.getCause().getMessage()); + } + return null; } - return null; - } - }); + }); Thread.sleep(TimeUnit.SECONDS.toMillis(testParams.getPrePollWaitSeconds())); @@ -89,6 +89,8 @@ public class TestUtils { tailer.stop(); assertResults(tempFile, testParams); + if (testParams.getLabelFileCreated()) + assertLabels(labelFile, testParams); } public static void assertResults(File outputFile, org.apache.samoa.TestParams testParams) throws IOException { @@ -136,6 +138,32 @@ public class TestUtils { testParams.getKappaTempStat() <= Float.parseFloat(last.get(4 + 3 * cvEvaluation))); } + + public static void assertLabels(File labelFile, org.apache.samoa.TestParams testParams) throws IOException { + LOG.info("Checking labels file " + labelFile.getAbsolutePath()); + //1. parse result file with csv parser + Reader in = new FileReader(labelFile); + long lineCount = 0; + long expectedLineCount = testParams.getInputInstances() / testParams.getLabelSamplingSize(); + Iterable<CSVRecord> records = CSVFormat.EXCEL.withSkipHeaderRecord(false) + .withIgnoreEmptyLines(true).withDelimiter(',').withCommentMarker('#').parse(in); + + Iterator<CSVRecord> iterator = records.iterator(); + CSVRecord header = iterator.next(); + + while (iterator.hasNext()) { + iterator.next(); + lineCount = lineCount + 1; + } + + Assert.assertEquals("Unexpected column", org.apache.samoa.TestParams.INSTANCE_ID, header.get(0).trim()); + Assert.assertEquals("Unexpected column", org.apache.samoa.TestParams.TRUE_CLASS_VALUE, header.get(1).trim()); + Assert.assertEquals("Unexpected column", org.apache.samoa.TestParams.PREDICTED_CLASS_VALUE, header.get(2).trim()); + for (int i = 3; i < header.size(); i++) + Assert.assertEquals("Unexpected column", org.apache.samoa.TestParams.VOTES, header.get(i).trim().substring(0, org.apache.samoa.TestParams.VOTES.length())); + Assert.assertEquals("Wrong number of lines in prediction file", expectedLineCount, lineCount); + + } private static class TestResultsTailerAdapter extends TailerListenerAdapter { http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/dbc3aab1/samoa-threads/src/test/java/org/apache/samoa/AlgosTest.java ---------------------------------------------------------------------- diff --git a/samoa-threads/src/test/java/org/apache/samoa/AlgosTest.java b/samoa-threads/src/test/java/org/apache/samoa/AlgosTest.java index 031d98d..f43d667 100644 --- a/samoa-threads/src/test/java/org/apache/samoa/AlgosTest.java +++ b/samoa-threads/src/test/java/org/apache/samoa/AlgosTest.java @@ -35,6 +35,7 @@ public class AlgosTest { .samplingSize(20_000) .evaluationInstances(200_000) .classifiedInstances(200_000) + .labelSamplingSize(10l) .classificationsCorrect(55f) .kappaStat(-0.1f) .kappaTempStat(-0.1f) @@ -55,6 +56,7 @@ public class AlgosTest { .inputDelayMicroSec(100) // prevents saturating the system due to unbounded queues .evaluationInstances(90_000) .classifiedInstances(100_000) + .labelSamplingSize(10l) .classificationsCorrect(55f) .kappaStat(0f) .kappaTempStat(0f)
