SAMOA-34: Fix Bagging
Project: http://git-wip-us.apache.org/repos/asf/incubator-samoa/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-samoa/commit/87d37322 Tree: http://git-wip-us.apache.org/repos/asf/incubator-samoa/tree/87d37322 Diff: http://git-wip-us.apache.org/repos/asf/incubator-samoa/diff/87d37322 Branch: refs/heads/master Commit: 87d373221864617e336b94ec873d8beaf7bc88eb Parents: 31f82ec Author: Gianmarco De Francisci Morales <[email protected]> Authored: Thu Jun 18 16:08:34 2015 +0300 Committer: abifet <[email protected]> Committed: Wed Jul 1 18:07:04 2015 +0800 ---------------------------------------------------------------------- .../learners/classifiers/ensemble/Bagging.java | 61 ++++---- .../ensemble/BaggingDistributorProcessor.java | 140 ++++++------------- .../ensemble/PredictionCombinerProcessor.java | 10 +- .../trees/VerticalHoeffdingTree.java | 5 +- 4 files changed, 83 insertions(+), 133 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/87d37322/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/ensemble/Bagging.java ---------------------------------------------------------------------- diff --git a/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/ensemble/Bagging.java b/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/ensemble/Bagging.java index 7355b1a..43bc07c 100644 --- a/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/ensemble/Bagging.java +++ b/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/ensemble/Bagging.java @@ -24,8 +24,6 @@ package org.apache.samoa.learners.classifiers.ensemble; * License */ -import com.google.common.collect.ImmutableSet; - import java.util.Set; import org.apache.samoa.core.Processor; @@ -34,10 +32,13 @@ import org.apache.samoa.learners.Learner; import org.apache.samoa.learners.classifiers.trees.VerticalHoeffdingTree; import org.apache.samoa.topology.Stream; import org.apache.samoa.topology.TopologyBuilder; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import com.github.javacliparser.ClassOption; import com.github.javacliparser.Configurable; import com.github.javacliparser.IntOption; +import com.google.common.collect.ImmutableSet; /** * The Bagging Classifier by Oza and Russell. @@ -46,6 +47,7 @@ public class Bagging implements Learner, Configurable { /** The Constant serialVersionUID. */ private static final long serialVersionUID = -2971850264864952099L; + private static final Logger logger = LoggerFactory.getLogger(Bagging.class); /** The base learner option. */ public ClassOption baseLearnerOption = new ClassOption("baseLearner", 'l', @@ -58,11 +60,8 @@ public class Bagging implements Learner, Configurable { /** The distributor processor. */ private BaggingDistributorProcessor distributorP; - /** The training stream. */ - private Stream testingStream; - - /** The prediction stream. */ - private Stream predictionStream; + /** The input streams for the ensemble, one per member. */ + private Stream[] ensembleStreams; /** The result stream. */ protected Stream resultStream; @@ -70,45 +69,57 @@ public class Bagging implements Learner, Configurable { /** The dataset. */ private Instances dataset; - protected Learner classifier; + protected Learner[] ensemble; protected int parallelism; /** * Sets the layout. + * + * @throws Exception */ protected void setLayout() { - - int sizeEnsemble = this.ensembleSizeOption.getValue(); + int ensembleSize = this.ensembleSizeOption.getValue(); distributorP = new BaggingDistributorProcessor(); - distributorP.setSizeEnsemble(sizeEnsemble); - this.builder.addProcessor(distributorP, 1); + distributorP.setEnsembleSize(ensembleSize); + builder.addProcessor(distributorP, 1); // instantiate classifier - classifier = (Learner) this.baseLearnerOption.getValue(); - classifier.init(builder, this.dataset, sizeEnsemble); + ensemble = new Learner[ensembleSize]; + for (int i = 0; i < ensembleSize; i++) { + try { + ensemble[i] = (Learner) ClassOption.createObject(baseLearnerOption.getValueAsCLIString(), + baseLearnerOption.getRequiredType()); + } catch (Exception e) { + logger.error("Unable to create members of the ensemble. Please check your CLI parameters"); + e.printStackTrace(); + throw new IllegalArgumentException(e); + } + ensemble[i].init(builder, this.dataset, 1); // sequential + } PredictionCombinerProcessor predictionCombinerP = new PredictionCombinerProcessor(); - predictionCombinerP.setSizeEnsemble(sizeEnsemble); + predictionCombinerP.setEnsembleSize(ensembleSize); this.builder.addProcessor(predictionCombinerP, 1); // Streams - resultStream = this.builder.createStream(predictionCombinerP); + resultStream = builder.createStream(predictionCombinerP); predictionCombinerP.setOutputStream(resultStream); - for (Stream subResultStream : classifier.getResultStreams()) { - this.builder.connectInputKeyStream(subResultStream, predictionCombinerP); + for (Learner member : ensemble) { + for (Stream subResultStream : member.getResultStreams()) { // a learner can have multiple output streams + this.builder.connectInputKeyStream(subResultStream, predictionCombinerP); // the key is the instance id to combine predictions + } } - testingStream = this.builder.createStream(distributorP); - this.builder.connectInputKeyStream(testingStream, classifier.getInputProcessor()); - - predictionStream = this.builder.createStream(distributorP); - this.builder.connectInputKeyStream(predictionStream, classifier.getInputProcessor()); + ensembleStreams = new Stream[ensembleSize]; + for (int i = 0; i < ensembleSize; i++) { + ensembleStreams[i] = builder.createStream(distributorP); + builder.connectInputShuffleStream(ensembleStreams[i], ensemble[i].getInputProcessor()); // connect streams one-to-one with ensemble members (the type of connection does not matter) + } - distributorP.setOutputStream(testingStream); - distributorP.setPredictionStream(predictionStream); + distributorP.setOutputStreams(ensembleStreams); } /** The builder. */ http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/87d37322/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/ensemble/BaggingDistributorProcessor.java ---------------------------------------------------------------------- diff --git a/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/ensemble/BaggingDistributorProcessor.java b/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/ensemble/BaggingDistributorProcessor.java index 33615db..6c88d94 100644 --- a/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/ensemble/BaggingDistributorProcessor.java +++ b/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/ensemble/BaggingDistributorProcessor.java @@ -24,6 +24,7 @@ package org.apache.samoa.learners.classifiers.ensemble; * License */ +import java.util.Arrays; import java.util.Random; import org.apache.samoa.core.ContentEvent; @@ -38,19 +39,16 @@ import org.apache.samoa.topology.Stream; */ public class BaggingDistributorProcessor implements Processor { - /** - * - */ private static final long serialVersionUID = -1550901409625192730L; - /** The size ensemble. */ - private int sizeEnsemble; + /** The ensemble size. */ + private int ensembleSize; - /** The training stream. */ - private Stream trainingStream; + /** The stream ensemble. */ + private Stream[] ensembleStreams; - /** The prediction stream. */ - private Stream predictionStream; + /** Ramdom number generator. */ + protected Random random = new Random(); //TODO make random seed configurable /** * On event. @@ -60,38 +58,34 @@ public class BaggingDistributorProcessor implements Processor { * @return true, if successful */ public boolean process(ContentEvent event) { - InstanceContentEvent inEvent = (InstanceContentEvent) event; // ((s4Event)event).getContentEvent(); - // InstanceEvent inEvent = (InstanceEvent) event; + InstanceContentEvent inEvent = (InstanceContentEvent) event; if (inEvent.getInstanceIndex() < 0) { - // End learning - predictionStream.put(event); + // end learning + for (Stream stream : ensembleStreams) + stream.put(event); return false; } if (inEvent.isTesting()) { - Instance trainInst = inEvent.getInstance(); - for (int i = 0; i < sizeEnsemble; i++) { - Instance weightedInst = trainInst.copy(); - // weightedInst.setWeight(trainInst.weight() * k); - InstanceContentEvent instanceContentEvent = new InstanceContentEvent( - inEvent.getInstanceIndex(), weightedInst, false, true); - instanceContentEvent.setClassifierIndex(i); - instanceContentEvent.setEvaluationIndex(inEvent.getEvaluationIndex()); - predictionStream.put(instanceContentEvent); + Instance testInstance = inEvent.getInstance(); + for (int i = 0; i < ensembleSize; i++) { + Instance instanceCopy = testInstance.copy(); + InstanceContentEvent instanceContentEvent = new InstanceContentEvent(inEvent.getInstanceIndex(), instanceCopy, + false, true); + instanceContentEvent.setClassifierIndex(i); //TODO probably not needed anymore + instanceContentEvent.setEvaluationIndex(inEvent.getEvaluationIndex()); //TODO probably not needed anymore + ensembleStreams[i].put(instanceContentEvent); } } - /* Estimate model parameters using the training data. */ + // estimate model parameters using the training data if (inEvent.isTraining()) { train(inEvent); } - return false; + return true; } - /** The random. */ - protected Random random = new Random(); - /** * Train. * @@ -99,104 +93,51 @@ public class BaggingDistributorProcessor implements Processor { * the in event */ protected void train(InstanceContentEvent inEvent) { - Instance trainInst = inEvent.getInstance(); - for (int i = 0; i < sizeEnsemble; i++) { + Instance trainInstance = inEvent.getInstance(); + for (int i = 0; i < ensembleSize; i++) { int k = MiscUtils.poisson(1.0, this.random); if (k > 0) { - Instance weightedInst = trainInst.copy(); - weightedInst.setWeight(trainInst.weight() * k); - InstanceContentEvent instanceContentEvent = new InstanceContentEvent( - inEvent.getInstanceIndex(), weightedInst, true, false); + Instance weightedInstance = trainInstance.copy(); + weightedInstance.setWeight(trainInstance.weight() * k); + InstanceContentEvent instanceContentEvent = new InstanceContentEvent(inEvent.getInstanceIndex(), + weightedInstance, true, false); instanceContentEvent.setClassifierIndex(i); instanceContentEvent.setEvaluationIndex(inEvent.getEvaluationIndex()); - trainingStream.put(instanceContentEvent); + ensembleStreams[i].put(instanceContentEvent); } } } - /* - * (non-Javadoc) - * - * @see org.apache.s4.core.ProcessingElement#onCreate() - */ @Override public void onCreate(int id) { // do nothing } - /** - * Gets the training stream. - * - * @return the training stream - */ - public Stream getTrainingStream() { - return trainingStream; - } - - /** - * Sets the training stream. - * - * @param trainingStream - * the new training stream - */ - public void setOutputStream(Stream trainingStream) { - this.trainingStream = trainingStream; - } - - /** - * Gets the prediction stream. - * - * @return the prediction stream - */ - public Stream getPredictionStream() { - return predictionStream; + public Stream[] getOutputStreams() { + return ensembleStreams; } - /** - * Sets the prediction stream. - * - * @param predictionStream - * the new prediction stream - */ - public void setPredictionStream(Stream predictionStream) { - this.predictionStream = predictionStream; + public void setOutputStreams(Stream[] ensembleStreams) { + this.ensembleStreams = ensembleStreams; } - /** - * Gets the size ensemble. - * - * @return the size ensemble - */ - public int getSizeEnsemble() { - return sizeEnsemble; + public int getEnsembleSize() { + return ensembleSize; } - /** - * Sets the size ensemble. - * - * @param sizeEnsemble - * the new size ensemble - */ - public void setSizeEnsemble(int sizeEnsemble) { - this.sizeEnsemble = sizeEnsemble; + public void setEnsembleSize(int ensembleSize) { + this.ensembleSize = ensembleSize; } - /* - * (non-Javadoc) - * - * @see samoa.core.Processor#newProcessor(samoa.core.Processor) - */ @Override public Processor newProcessor(Processor sourceProcessor) { BaggingDistributorProcessor newProcessor = new BaggingDistributorProcessor(); BaggingDistributorProcessor originProcessor = (BaggingDistributorProcessor) sourceProcessor; - if (originProcessor.getPredictionStream() != null) { - newProcessor.setPredictionStream(originProcessor.getPredictionStream()); + if (originProcessor.getOutputStreams() != null) { + newProcessor.setOutputStreams(Arrays.copyOf(originProcessor.getOutputStreams(), + originProcessor.getOutputStreams().length)); } - if (originProcessor.getTrainingStream() != null) { - newProcessor.setOutputStream(originProcessor.getTrainingStream()); - } - newProcessor.setSizeEnsemble(originProcessor.getSizeEnsemble()); + newProcessor.setEnsembleSize(originProcessor.getEnsembleSize()); /* * if (originProcessor.getLearningCurve() != null){ * newProcessor.setLearningCurve((LearningCurve) @@ -204,5 +145,4 @@ public class BaggingDistributorProcessor implements Processor { */ return newProcessor; } - } http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/87d37322/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/ensemble/PredictionCombinerProcessor.java ---------------------------------------------------------------------- diff --git a/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/ensemble/PredictionCombinerProcessor.java b/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/ensemble/PredictionCombinerProcessor.java index 2e5f335..76e84f8 100644 --- a/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/ensemble/PredictionCombinerProcessor.java +++ b/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/ensemble/PredictionCombinerProcessor.java @@ -33,14 +33,14 @@ import org.apache.samoa.moa.core.DoubleVector; import org.apache.samoa.topology.Stream; /** - * The Class PredictionCombinerProcessor. + * Combines predictions coming from an ensemble. Equivalent to a majority-vote classifier. */ public class PredictionCombinerProcessor implements Processor { private static final long serialVersionUID = -1606045723451191132L; /** - * The size ensemble. + * The ensemble size. */ protected int ensembleSize; @@ -73,7 +73,7 @@ public class PredictionCombinerProcessor implements Processor { * * @return the ensembleSize */ - public int getSizeEnsemble() { + public int getEnsembleSize() { return ensembleSize; } @@ -83,7 +83,7 @@ public class PredictionCombinerProcessor implements Processor { * @param ensembleSize * the new size ensemble */ - public void setSizeEnsemble(int ensembleSize) { + public void setEnsembleSize(int ensembleSize) { this.ensembleSize = ensembleSize; } @@ -143,7 +143,7 @@ public class PredictionCombinerProcessor implements Processor { if (originProcessor.getOutputStream() != null) { newProcessor.setOutputStream(originProcessor.getOutputStream()); } - newProcessor.setSizeEnsemble(originProcessor.getSizeEnsemble()); + newProcessor.setEnsembleSize(originProcessor.getEnsembleSize()); return newProcessor; } http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/87d37322/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/trees/VerticalHoeffdingTree.java ---------------------------------------------------------------------- diff --git a/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/trees/VerticalHoeffdingTree.java b/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/trees/VerticalHoeffdingTree.java index ea7e53d..6534cee 100644 --- a/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/trees/VerticalHoeffdingTree.java +++ b/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/trees/VerticalHoeffdingTree.java @@ -20,8 +20,6 @@ package org.apache.samoa.learners.classifiers.trees; * #L% */ -import com.google.common.collect.ImmutableSet; - import java.util.Set; import org.apache.samoa.core.Processor; @@ -41,6 +39,7 @@ import com.github.javacliparser.Configurable; import com.github.javacliparser.FlagOption; import com.github.javacliparser.FloatOption; import com.github.javacliparser.IntOption; +import com.google.common.collect.ImmutableSet; /** * Vertical Hoeffding Tree. @@ -172,7 +171,7 @@ public final class VerticalHoeffdingTree implements ClassificationLearner, Adapt public void setChangeDetector(ChangeDetector cd) { this.changeDetector = cd; } - + static class LearningNodeIdGenerator { // TODO: add code to warn user of when value reaches Long.MAX_VALUES
