SAMOA-53: Add Prequential Cross-Validation Evaluation
Project: http://git-wip-us.apache.org/repos/asf/incubator-samoa/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-samoa/commit/7a177650 Tree: http://git-wip-us.apache.org/repos/asf/incubator-samoa/tree/7a177650 Diff: http://git-wip-us.apache.org/repos/asf/incubator-samoa/diff/7a177650 Branch: refs/heads/master Commit: 7a177650f54096457b339d797d6ab3d4b558856b Parents: 1bd1012 Author: Albert Bifet <[email protected]> Authored: Mon Dec 7 15:00:09 2015 +0100 Committer: Albert Bifet <[email protected]> Committed: Tue Mar 15 20:39:25 2016 +0100 ---------------------------------------------------------------------- .../EvaluationDistributorProcessor.java | 168 +++++++++++ .../samoa/evaluation/EvaluatorCVProcessor.java | 288 +++++++++++++++++++ .../samoa/tasks/PrequentialCVEvaluation.java | 166 +++++++++++ .../samoa/tasks/PrequentialEvaluation.java | 12 +- 4 files changed, 628 insertions(+), 6 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/7a177650/samoa-api/src/main/java/org/apache/samoa/evaluation/EvaluationDistributorProcessor.java ---------------------------------------------------------------------- diff --git a/samoa-api/src/main/java/org/apache/samoa/evaluation/EvaluationDistributorProcessor.java b/samoa-api/src/main/java/org/apache/samoa/evaluation/EvaluationDistributorProcessor.java new file mode 100644 index 0000000..a243a10 --- /dev/null +++ b/samoa-api/src/main/java/org/apache/samoa/evaluation/EvaluationDistributorProcessor.java @@ -0,0 +1,168 @@
package org.apache.samoa.evaluation;

/*
 * #%L
 * SAMOA
 * %%
 * Copyright (C) 2014 - 2015 Apache Software Foundation
 * %%
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * #L%
 */

import com.google.common.base.Preconditions;
import org.apache.samoa.core.ContentEvent;
import org.apache.samoa.core.Processor;
import org.apache.samoa.instances.Instance;
import org.apache.samoa.learners.InstanceContentEvent;
import org.apache.samoa.moa.core.MiscUtils;
import org.apache.samoa.topology.Stream;

import java.util.Arrays;
import java.util.Random;

/**
 * Distributes the instance stream over {@code numberClassifiers} parallel models (folds) for
 * distributed prequential validation.
 *
 * <p>Every testing instance is copied, unweighted, to all folds. Every training instance is sent
 * to each fold with an integer weight that depends on {@link #validationMethodology}:
 * <ul>
 * <li>{@code 0} (cross-validation): fold {@code i} trains on instance {@code n} unless
 * {@code n % numberClassifiers == i};</li>
 * <li>{@code 1} (bootstrap): the weight is drawn from a Poisson(1) distribution;</li>
 * <li>{@code 2} (split validation): fold {@code i} trains on instance {@code n} only when
 * {@code n % numberClassifiers == i}.</li>
 * </ul>
 * Any other methodology index trains every fold on every instance with weight 1. A fold whose
 * weight is zero does not receive the instance at all.
 */
public class EvaluationDistributorProcessor implements Processor {

  private static final long serialVersionUID = -1550901409625192734L;

  /** Methodology index: k-fold distributed cross-validation. */
  private static final int CROSS_VALIDATION = 0;

  /** Methodology index: k-fold distributed bootstrap validation. */
  private static final int BOOTSTRAP_VALIDATION = 1;

  /** Methodology index: k-fold distributed split validation. */
  private static final int SPLIT_VALIDATION = 2;

  /** The ensemble size, i.e. the number of folds. */
  private int numberClassifiers;

  /** One output stream per fold; stream {@code i} feeds the model of fold {@code i}. */
  private Stream[] ensembleStreams;

  /** Random number generator used to draw bootstrap weights. */
  protected Random random = new Random();

  /** Seed of {@link #random}; copied by {@link #newProcessor(Processor)}. */
  protected int randomSeed;

  /** The methodology used to perform the validation (0, 1 or 2, see the class documentation). */
  public int validationMethodology;

  /**
   * Routes one event to the fold models.
   *
   * <p>An event with a negative instance index is treated as the end-of-stream marker and is
   * forwarded unchanged to every fold. Otherwise a testing event is copied to every fold and a
   * training event is distributed by {@link #train(InstanceContentEvent)}; an event may be both.
   *
   * @param event
   *          the incoming event; must be an {@link InstanceContentEvent}
   * @return {@code false} for the end-of-stream marker, {@code true} otherwise
   * @throws IllegalStateException
   *           if the output streams are unset or their number differs from the number of folds
   */
  @Override
  public boolean process(ContentEvent event) {
    Preconditions.checkState(ensembleStreams != null, "Output streams have not been set.");
    // Guava templates use %s placeholders and are only formatted when the check fails.
    Preconditions.checkState(numberClassifiers == ensembleStreams.length,
        "Ensemble size (%s) and number of ensemble streams (%s) do not match.",
        numberClassifiers, ensembleStreams.length);
    InstanceContentEvent inEvent = (InstanceContentEvent) event;

    if (inEvent.getInstanceIndex() < 0) {
      // end learning: every fold must see the marker
      for (Stream stream : ensembleStreams) {
        stream.put(event);
      }
      return false;
    }

    if (inEvent.isTesting()) {
      distributeTestInstance(inEvent);
    }

    // estimate model parameters using the training data
    if (inEvent.isTraining()) {
      train(inEvent);
    }
    return true;
  }

  /** Sends an unweighted copy of a testing instance to every fold, tagged with the fold index. */
  private void distributeTestInstance(InstanceContentEvent inEvent) {
    Instance testInstance = inEvent.getInstance();
    for (int i = 0; i < numberClassifiers; i++) {
      InstanceContentEvent testEvent = new InstanceContentEvent(inEvent.getInstanceIndex(),
          testInstance.copy(), false, true);
      // EvaluatorCVProcessor selects its per-fold evaluator by the result's evaluation index;
      // presumably the learner copies this value into its ResultContentEvent — confirm.
      testEvent.setEvaluationIndex(i);
      ensembleStreams[i].put(testEvent);
    }
  }

  /**
   * Sends a training instance to the folds, each copy weighted according to
   * {@link #validationMethodology}. Folds whose weight is zero do not receive the instance.
   *
   * @param inEvent
   *          the training event
   */
  protected void train(InstanceContentEvent inEvent) {
    Instance trainInstance = inEvent.getInstance();
    long instancesProcessed = inEvent.getInstanceIndex();
    for (int i = 0; i < numberClassifiers; i++) {
      int k = trainingWeight(instancesProcessed, i);
      if (k > 0) {
        Instance weightedInstance = trainInstance.copy();
        weightedInstance.setWeight(trainInstance.weight() * k);
        InstanceContentEvent trainEvent = new InstanceContentEvent(inEvent.getInstanceIndex(),
            weightedInstance, true, false);
        trainEvent.setEvaluationIndex(i);
        ensembleStreams[i].put(trainEvent);
      }
    }
  }

  /**
   * Returns how many times fold {@code fold} should train on the instance with index
   * {@code instanceIndex} (0 means the fold skips it).
   */
  private int trainingWeight(long instanceIndex, int fold) {
    boolean assignedToFold = instanceIndex % numberClassifiers == fold;
    switch (validationMethodology) {
      case CROSS_VALIDATION:
        // train on all folds except the one the instance is assigned to
        return assignedToFold ? 0 : 1;
      case BOOTSTRAP_VALIDATION:
        // Poisson(1)-distributed multiplicity per fold
        return MiscUtils.poisson(1, random);
      case SPLIT_VALIDATION:
        // train only on the fold the instance is assigned to
        return assignedToFold ? 1 : 0;
      default:
        return 1;
    }
  }

  @Override
  public void onCreate(int id) {
    // do nothing: configuration is set through the setters / copied by newProcessor
  }

  /** Returns the per-fold output streams (stream {@code i} feeds fold {@code i}). */
  public Stream[] getOutputStreams() {
    return ensembleStreams;
  }

  /** Sets the per-fold output streams; their number must equal the number of folds. */
  public void setOutputStreams(Stream[] ensembleStreams) {
    this.ensembleStreams = ensembleStreams;
  }

  /** Returns the number of folds. */
  public int getNumberClassifiers() {
    return numberClassifiers;
  }

  /** Sets the number of folds. */
  public void setNumberClassifiers(int numberClassifiers) {
    this.numberClassifiers = numberClassifiers;
  }

  /** Sets the validation methodology index (0, 1 or 2, see the class documentation). */
  public void setValidationMethodologyOption(int index) {
    this.validationMethodology = index;
  }

  /** Sets the seed and re-creates the random generator from it. */
  public void setRandomSeed(int seed) {
    this.randomSeed = seed;
    this.random = new Random(seed);
  }

  @Override
  public Processor newProcessor(Processor sourceProcessor) {
    EvaluationDistributorProcessor newProcessor = new EvaluationDistributorProcessor();
    EvaluationDistributorProcessor originProcessor = (EvaluationDistributorProcessor) sourceProcessor;
    if (originProcessor.getOutputStreams() != null) {
      newProcessor.setOutputStreams(Arrays.copyOf(originProcessor.getOutputStreams(),
          originProcessor.getOutputStreams().length));
    }
    newProcessor.setNumberClassifiers(originProcessor.getNumberClassifiers());
    newProcessor.setValidationMethodologyOption(originProcessor.validationMethodology);
    newProcessor.setRandomSeed(originProcessor.randomSeed);
    return newProcessor;
  }
}
http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/7a177650/samoa-api/src/main/java/org/apache/samoa/evaluation/EvaluatorCVProcessor.java ---------------------------------------------------------------------- diff --git
a/samoa-api/src/main/java/org/apache/samoa/evaluation/EvaluatorCVProcessor.java b/samoa-api/src/main/java/org/apache/samoa/evaluation/EvaluatorCVProcessor.java new file mode 100644 index 0000000..f282f0d --- /dev/null +++ b/samoa-api/src/main/java/org/apache/samoa/evaluation/EvaluatorCVProcessor.java @@ -0,0 +1,288 @@
package org.apache.samoa.evaluation;

/*
 * #%L
 * SAMOA
 * %%
 * Copyright (C) 2014 - 2015 Apache Software Foundation
 * %%
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * #L%
 */

import org.apache.samoa.core.ContentEvent;
import org.apache.samoa.core.Processor;
import org.apache.samoa.learners.ResultContentEvent;
import org.apache.samoa.moa.core.Measurement;
import org.apache.samoa.moa.evaluation.LearningCurve;
import org.apache.samoa.moa.evaluation.LearningEvaluation;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.PrintStream;
import java.util.*;
import java.util.concurrent.TimeUnit;

/**
 * Evaluator for distributed prequential validation.
 *
 * <p>Keeps one {@link PerformanceEvaluator} per fold. Each incoming {@link ResultContentEvent} is
 * added to the evaluator selected by its evaluation index. Once an instance has received
 * {@code foldNumber} results it counts as one evaluated instance. Every
 * {@code samplingFrequency} evaluated instances, a learning-curve entry is recorded from the fold
 * evaluators and optionally appended to the dump file. The event flagged as last concludes the
 * evaluation.
 */
public class EvaluatorCVProcessor implements Processor {

  private static final long serialVersionUID = -2778051819116753612L;

  private static final Logger logger =
      LoggerFactory.getLogger(EvaluatorCVProcessor.class);

  private static final String ORDERING_MEASUREMENT_NAME = "evaluation instances";

  /** One evaluator per fold, indexed by the result's evaluation index. */
  private final PerformanceEvaluator[] evaluators;
  private final int samplingFrequency;
  private final File dumpFile;
  private transient PrintStream immediateResultStream = null;
  private transient boolean firstDump = true;

  /** Number of instances for which all folds have reported a result. */
  private long totalCount = 0;
  private long experimentStart = 0;

  private long sampleStart = 0;

  private LearningCurve learningCurve;
  private int id;

  /** Number of results expected per instance (one per fold). */
  private int foldNumber = 10;

  /** Whether the timers were started by the first processed event. */
  private boolean initiated = false;

  /**
   * Number of results received so far per instance index. An entry is removed as soon as all
   * folds have reported, so only instances still waiting for results are kept.
   */
  protected Map<Integer, Integer> mapCountsforInstanceReceived;

  private EvaluatorCVProcessor(Builder builder) {
    this.foldNumber = builder.foldNumber;
    this.evaluators = new PerformanceEvaluator[builder.foldNumber];
    for (int i = 0; i < this.evaluators.length; i++) {
      // an independent copy per fold so that fold statistics do not mix
      evaluators[i] = (PerformanceEvaluator) builder.evaluator.copy();
    }
    this.samplingFrequency = builder.samplingFrequency;
    this.dumpFile = builder.dumpFile;
  }

  /**
   * Adds one fold result to its evaluator. When the last fold result for an instance arrives, it
   * counts the instance and either concludes the evaluation (last event) or records a
   * learning-curve sample every {@code samplingFrequency} instances.
   *
   * @param event
   *          a {@link ResultContentEvent}
   * @return {@code true} once the last event has been fully received, {@code false} otherwise
   */
  @Override
  public boolean process(ContentEvent event) {
    if (!this.initiated) {
      sampleStart = System.nanoTime();
      experimentStart = sampleStart;
      this.initiated = true;
    }

    ResultContentEvent result = (ResultContentEvent) event;
    // NOTE(review): the long index is narrowed to int to key the count map; indices beyond
    // Integer.MAX_VALUE would collide.
    int instanceIndex = (int) result.getInstanceIndex();

    addStatisticsForInstanceReceived(instanceIndex, result.getEvaluationIndex(), 1);

    evaluators[result.getEvaluationIndex()].addResult(result.getInstance(), result.getClassVotes());

    if (hasAllVotesArrivedInstance(instanceIndex)) {
      // every fold has reported: drop the counter so the map does not grow without bound
      this.mapCountsforInstanceReceived.remove(instanceIndex);
      totalCount += 1;
      if (result.isLastEvent()) {
        this.concludeMeasurement();
        return true;
      }

      if (totalCount % samplingFrequency == 0) {
        long sampleEnd = System.nanoTime();
        long sampleDuration = TimeUnit.SECONDS.convert(sampleEnd - sampleStart, TimeUnit.NANOSECONDS);
        sampleStart = sampleEnd;

        logger.info("{} seconds for {} instances", sampleDuration, samplingFrequency);
        this.addMeasurement();
      }
    }

    return false;
  }

  /** Returns whether exactly {@code foldNumber} results have been received for the instance. */
  private boolean hasAllVotesArrivedInstance(int instanceIndex) {
    if (this.mapCountsforInstanceReceived == null) {
      return false;
    }
    Integer count = this.mapCountsforInstanceReceived.get(instanceIndex);
    return count != null && count == this.foldNumber;
  }

  /**
   * Adds {@code add} to the number of results received for {@code instanceIndex}, creating the
   * count map on first use. {@code evaluationIndex} is currently not used.
   */
  protected void addStatisticsForInstanceReceived(int instanceIndex, int evaluationIndex, int add) {
    if (this.mapCountsforInstanceReceived == null) {
      this.mapCountsforInstanceReceived = new HashMap<>();
    }
    Integer count = this.mapCountsforInstanceReceived.get(instanceIndex);
    if (count == null) {
      count = 0;
    }
    this.mapCountsforInstanceReceived.put(instanceIndex, count + add);
  }

  /**
   * Creates the learning curve and, when a dump file is configured, opens it in append mode.
   * Failures to open the file are logged and leave the dump disabled.
   */
  @Override
  public void onCreate(int id) {
    this.id = id;
    this.learningCurve = new LearningCurve(ORDERING_MEASUREMENT_NAME);

    if (this.dumpFile != null) {
      try {
        // append mode also creates the file when it does not exist yet
        this.immediateResultStream = new PrintStream(new FileOutputStream(dumpFile, true), true);
      } catch (FileNotFoundException e) {
        this.immediateResultStream = null;
        logger.error("File not found exception for {}:{}", this.dumpFile.getAbsolutePath(), e.toString());
      } catch (Exception e) {
        this.immediateResultStream = null;
        logger.error("Exception when creating {}:{}", this.dumpFile.getAbsolutePath(), e.toString());
      }
    }

    this.firstDump = true;
  }

  /** Builds a copy with the same configuration, sharing the source's learning curve if set. */
  @Override
  public Processor newProcessor(Processor p) {
    EvaluatorCVProcessor originalProcessor = (EvaluatorCVProcessor) p;
    EvaluatorCVProcessor newProcessor = new EvaluatorCVProcessor.Builder(originalProcessor).build();

    if (originalProcessor.learningCurve != null) {
      newProcessor.learningCurve = originalProcessor.learningCurve;
    }

    return newProcessor;
  }

  @Override
  public String toString() {
    StringBuilder report = new StringBuilder();

    report.append(EvaluatorCVProcessor.class.getCanonicalName());
    report.append(" id = ").append(this.id);
    report.append('\n');

    // learningCurve is only created in onCreate
    if (learningCurve != null && learningCurve.numEntries() > 0) {
      report.append(learningCurve.toString());
      report.append('\n');
    }
    return report.toString();
  }

  /** Records one learning-curve entry and mirrors it to the dump file when enabled. */
  private void addMeasurement() {
    List<Measurement> measurements = new ArrayList<>();
    measurements.add(new Measurement(ORDERING_MEASUREMENT_NAME, totalCount));

    Measurement[] finalMeasurements = getEvaluationMeasurements(
        measurements.toArray(new Measurement[measurements.size()]), evaluators);

    LearningEvaluation learningEvaluation = new LearningEvaluation(finalMeasurements);
    learningCurve.insertEntry(learningEvaluation);
    logger.debug("evaluator id = {}", this.id);
    logger.info(learningEvaluation.toString());

    if (immediateResultStream != null) {
      if (firstDump) {
        immediateResultStream.println(learningCurve.headerToString());
        firstDump = false;
      }

      immediateResultStream.println(learningCurve.entryToString(learningCurve.numEntries() - 1));
      immediateResultStream.flush();
    }
  }

  /** Logs the final summary, marks the dump file as completed and closes it. */
  private void concludeMeasurement() {
    logger.info("last event is received!");
    logger.info("total count: {}", this.totalCount);

    String learningCurveSummary = this.toString();
    logger.info(learningCurveSummary);

    long experimentEnd = System.nanoTime();
    long totalExperimentTime = TimeUnit.SECONDS.convert(experimentEnd - experimentStart, TimeUnit.NANOSECONDS);
    logger.info("total evaluation time: {} seconds for {} instances", totalExperimentTime, totalCount);

    if (immediateResultStream != null) {
      immediateResultStream.println("# COMPLETED");
      immediateResultStream.flush();
      // release the file handle; nothing is written after completion
      immediateResultStream.close();
      immediateResultStream = null;
    }
  }

  /** Builder for {@link EvaluatorCVProcessor}. */
  public static class Builder {

    private final PerformanceEvaluator evaluator;
    private int samplingFrequency = 100000;
    private File dumpFile = null;
    private int foldNumber = 10;

    /**
     * @param evaluator
     *          template evaluator; one copy is made per fold
     */
    public Builder(PerformanceEvaluator evaluator) {
      this.evaluator = evaluator;
    }

    /** Copies the configuration (including the number of folds) of an existing processor. */
    public Builder(EvaluatorCVProcessor oldProcessor) {
      this.evaluator = oldProcessor.evaluators[0];
      this.samplingFrequency = oldProcessor.samplingFrequency;
      this.dumpFile = oldProcessor.dumpFile;
      // without this, copies silently fall back to the default of 10 folds
      this.foldNumber = oldProcessor.foldNumber;
    }

    public Builder samplingFrequency(int samplingFrequency) {
      this.samplingFrequency = samplingFrequency;
      return this;
    }

    public Builder dumpFile(File file) {
      this.dumpFile = file;
      return this;
    }

    public Builder foldNumber(int foldNumber) {
      this.foldNumber = foldNumber;
      return this;
    }

    public EvaluatorCVProcessor build() {
      return new EvaluatorCVProcessor(this);
    }
  }

  /**
   * Combines the model measurements with the fold measurements.
   *
   * @param modelMeasurements
   *          measurements copied first into the result; may be {@code null}
   * @param subEvaluators
   *          per-fold evaluators; {@code null} entries are skipped
   * @return the model measurements followed by the fold measurements combined through
   *         {@link Measurement#averageMeasurements}
   */
  public Measurement[] getEvaluationMeasurements(Measurement[] modelMeasurements, PerformanceEvaluator[] subEvaluators) {
    List<Measurement> measurementList = new ArrayList<>();
    if (modelMeasurements != null) {
      measurementList.addAll(Arrays.asList(modelMeasurements));
    }
    // add average of sub-model measurements
    if ((subEvaluators != null) && (subEvaluators.length > 0)) {
      List<Measurement[]> subMeasurements = new ArrayList<>();
      for (PerformanceEvaluator subEvaluator : subEvaluators) {
        if (subEvaluator != null) {
          subMeasurements.add(subEvaluator.getPerformanceMeasurements());
        }
      }
      Measurement[] avgMeasurements = Measurement.averageMeasurements(
          subMeasurements.toArray(new Measurement[subMeasurements.size()][]));
      measurementList.addAll(Arrays.asList(avgMeasurements));
    }
    return measurementList.toArray(new Measurement[measurementList.size()]);
  }
}
http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/7a177650/samoa-api/src/main/java/org/apache/samoa/tasks/PrequentialCVEvaluation.java ---------------------------------------------------------------------- diff --git a/samoa-api/src/main/java/org/apache/samoa/tasks/PrequentialCVEvaluation.java b/samoa-api/src/main/java/org/apache/samoa/tasks/PrequentialCVEvaluation.java new file mode 100644 index 0000000..5794e8c --- /dev/null +++ b/samoa-api/src/main/java/org/apache/samoa/tasks/PrequentialCVEvaluation.java @@ -0,0 +1,166 @@ +package org.apache.samoa.tasks; + +/* + * #%L + * SAMOA + * %% + * Copyright (C) 2014 - 2015 Apache Software Foundation + * %% + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ * #L% + */

import com.github.javacliparser.*;
import org.apache.samoa.evaluation.*;
import org.apache.samoa.learners.ClassificationLearner;
import org.apache.samoa.learners.Learner;
import org.apache.samoa.learners.RegressionLearner;
import org.apache.samoa.learners.classifiers.trees.VerticalHoeffdingTree;
import org.apache.samoa.moa.streams.InstanceStream;
import org.apache.samoa.moa.streams.generators.RandomTreeGenerator;
import org.apache.samoa.streams.PrequentialSourceProcessor;
import org.apache.samoa.topology.ComponentFactory;
import org.apache.samoa.topology.Stream;
import org.apache.samoa.topology.Topology;
import org.apache.samoa.topology.TopologyBuilder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.text.SimpleDateFormat;
import java.util.Date;

/**
 * Prequential cross-validation evaluation. Evaluates the performance of online classifiers using
 * prequential cross-validation: each instance is first used to test the models and then to train
 * them, using one of these strategies: k-fold distributed cross-validation, k-fold distributed
 * bootstrap validation, or k-fold distributed split validation.
 *
 * <p>Topology: source -> {@link EvaluationDistributorProcessor} -> one learner per fold ->
 * {@link EvaluatorCVProcessor}.
 *
 * <p>More information in: Albert Bifet, Gianmarco De Francisci Morales, Jesse Read, Geoff Holmes,
 * Bernhard Pfahringer: Efficient Online Evaluation of Big Data Stream Classifiers. KDD 2015: 59-68
 */
public class PrequentialCVEvaluation extends PrequentialEvaluation {

  public IntOption foldNumberOption = new IntOption("foldNumber", 'x',
      "The number of distributed models.", 10, 1, Integer.MAX_VALUE);

  public MultiChoiceOption validationMethodologyOption = new MultiChoiceOption(
      "validationMethodology", 'a', "Validation methodology to use.", new String[] {
          "Cross-Validation", "Bootstrap-Validation", "Split-Validation" },
      new String[] { "k-fold distributed Cross Validation",
          "k-fold distributed Bootstrap Validation",
          "k-fold distributed Split Validation"
      }, 0);

  public IntOption randomSeedOption = new IntOption("randomSeed", 'r',
      "Seed for random behaviour of the task.", 1);

  @Override
  public void getDescription(StringBuilder sb, int indent) {
    sb.append("Prequential CV evaluation");
  }

  /** The distributor processor. */
  private EvaluationDistributorProcessor distributorP;

  /** One learner per fold. */
  protected Learner[] ensemble;

  private EvaluatorCVProcessor evaluator;

  private static final Logger logger = LoggerFactory.getLogger(PrequentialCVEvaluation.class);

  /**
   * Builds the prequential CV topology.
   *
   * @throws IllegalArgumentException
   *           if a fold learner cannot be created from the CLI options
   */
  @Override
  public void init() {
    // TODO remove the if statement
    // theoretically, dynamic binding will work here!
    // test later!
    // for now, the if statement is used by Storm
    if (builder == null) {
      builder = new TopologyBuilder();
      logger.debug("Successfully instantiating TopologyBuilder");

      builder.initTopology(evaluationNameOption.getValue());
      logger.debug("Successfully initializing SAMOA topology with name {}", evaluationNameOption.getValue());
    }

    // instantiate PrequentialSourceProcessor and its output stream (sourcePiOutputStream)
    preqSource = new PrequentialSourceProcessor();
    preqSource.setStreamSource((InstanceStream) this.streamTrainOption.getValue());
    preqSource.setMaxNumInstances(instanceLimitOption.getValue());
    preqSource.setSourceDelay(sourceDelayOption.getValue());
    preqSource.setDelayBatchSize(batchDelayOption.getValue());
    builder.addEntranceProcessor(preqSource);
    logger.debug("Successfully instantiating PrequentialSourceProcessor");

    sourcePiOutputStream = builder.createStream(preqSource);

    int foldNumber = this.foldNumberOption.getValue();

    // distributor: copies testing instances to all folds, weights training instances per fold
    distributorP = new EvaluationDistributorProcessor();
    distributorP.setNumberClassifiers(foldNumber);
    distributorP.setValidationMethodologyOption(this.validationMethodologyOption.getChosenIndex());
    distributorP.setRandomSeed(this.randomSeedOption.getValue());
    builder.addProcessor(distributorP, 1);
    builder.connectInputAllStream(sourcePiOutputStream, distributorP);

    // instantiate one classifier per fold
    ensemble = new Learner[foldNumber];
    for (int i = 0; i < foldNumber; i++) {
      try {
        ensemble[i] = (Learner) ClassOption.createObject(learnerOption.getValueAsCLIString(),
            learnerOption.getRequiredType());
      } catch (Exception e) {
        // pass the cause to the logger instead of printing the stack trace to stderr
        logger.error("Unable to create classifiers for the distributed evaluation. Please check your CLI parameters", e);
        throw new IllegalArgumentException(e);
      }
      ensemble[i].init(builder, preqSource.getDataset(), 1); // sequential
    }
    logger.debug("Successfully instantiating Classifiers");

    Stream[] ensembleStreams = new Stream[foldNumber];
    for (int i = 0; i < foldNumber; i++) {
      ensembleStreams[i] = builder.createStream(distributorP);
      // connect streams one-to-one with ensemble members (the type of connection does not matter)
      builder.connectInputShuffleStream(ensembleStreams[i], ensemble[i].getInputProcessor());
    }
    distributorP.setOutputStreams(ensembleStreams);

    PerformanceEvaluator evaluatorOptionValue = this.evaluatorOption.getValue();
    if (!isLearnerAndEvaluatorCompatible(ensemble[0], evaluatorOptionValue)) {
      evaluatorOptionValue = getDefaultPerformanceEvaluatorForLearner(ensemble[0]);
    }
    evaluator = new EvaluatorCVProcessor.Builder(evaluatorOptionValue)
        .samplingFrequency(sampleFrequencyOption.getValue())
        .dumpFile(dumpFileOption.getFile())
        .foldNumber(foldNumber).build();

    builder.addProcessor(evaluator, 1);

    for (Learner member : ensemble) {
      // a learner can have multiple output streams
      for (Stream subResultStream : member.getResultStreams()) {
        // the key is the instance id to combine predictions
        this.builder.connectInputKeyStream(subResultStream, evaluator);
      }
    }

    logger.debug("Successfully instantiating EvaluatorProcessor");

    prequentialTopology = builder.build();
    logger.debug("Successfully building the topology");
  }

}
http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/7a177650/samoa-api/src/main/java/org/apache/samoa/tasks/PrequentialEvaluation.java ---------------------------------------------------------------------- diff --git a/samoa-api/src/main/java/org/apache/samoa/tasks/PrequentialEvaluation.java b/samoa-api/src/main/java/org/apache/samoa/tasks/PrequentialEvaluation.java index da0057d..001622b 100644 ---
a/samoa-api/src/main/java/org/apache/samoa/tasks/PrequentialEvaluation.java +++ b/samoa-api/src/main/java/org/apache/samoa/tasks/PrequentialEvaluation.java @@ -99,13 +99,13 @@ public class PrequentialEvaluation implements Task, Configurable { public IntOption batchDelayOption = new IntOption("delayBatchSize", 'b', "The delay batch size: delay of x milliseconds after each batch ", 1, 1, Integer.MAX_VALUE); - private PrequentialSourceProcessor preqSource; + protected PrequentialSourceProcessor preqSource; // private PrequentialSourceTopologyStarter preqStarter; // private EntranceProcessingItem sourcePi; - private Stream sourcePiOutputStream; + protected Stream sourcePiOutputStream; private Learner classifier; @@ -115,9 +115,9 @@ public class PrequentialEvaluation implements Task, Configurable { // private Stream evaluatorPiInputStream; - private Topology prequentialTopology; + protected Topology prequentialTopology; - private TopologyBuilder builder; + protected TopologyBuilder builder; public void getDescription(StringBuilder sb, int indent) { sb.append("Prequential evaluation"); @@ -205,12 +205,12 @@ public class PrequentialEvaluation implements Task, Configurable { // return this.preqStarter; // } - private static boolean isLearnerAndEvaluatorCompatible(Learner learner, PerformanceEvaluator evaluator) { + protected static boolean isLearnerAndEvaluatorCompatible(Learner learner, PerformanceEvaluator evaluator) { return (learner instanceof RegressionLearner && evaluator instanceof RegressionPerformanceEvaluator) || (learner instanceof ClassificationLearner && evaluator instanceof ClassificationPerformanceEvaluator); } - private static PerformanceEvaluator getDefaultPerformanceEvaluatorForLearner(Learner learner) { + protected static PerformanceEvaluator getDefaultPerformanceEvaluatorForLearner(Learner learner) { if (learner instanceof RegressionLearner) { return new BasicRegressionPerformanceEvaluator(); }
