SAMOA-34: Fix AdaptiveBagging
Project: http://git-wip-us.apache.org/repos/asf/incubator-samoa/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-samoa/commit/fc44099a Tree: http://git-wip-us.apache.org/repos/asf/incubator-samoa/tree/fc44099a Diff: http://git-wip-us.apache.org/repos/asf/incubator-samoa/diff/fc44099a Branch: refs/heads/master Commit: fc44099a35508f1638b893f80450915c77305836 Parents: 87d3732 Author: Gianmarco De Francisci Morales <[email protected]> Authored: Thu Jun 18 16:20:46 2015 +0300 Committer: abifet <[email protected]> Committed: Wed Jul 1 18:07:04 2015 +0800 ---------------------------------------------------------------------- .../apache/samoa/learners/AdaptiveLearner.java | 3 +- .../classifiers/ensemble/AdaptiveBagging.java | 68 ++++++++++---------- .../learners/classifiers/ensemble/Bagging.java | 3 - 3 files changed, 35 insertions(+), 39 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/fc44099a/samoa-api/src/main/java/org/apache/samoa/learners/AdaptiveLearner.java ---------------------------------------------------------------------- diff --git a/samoa-api/src/main/java/org/apache/samoa/learners/AdaptiveLearner.java b/samoa-api/src/main/java/org/apache/samoa/learners/AdaptiveLearner.java index 28d0059..54af7b6 100644 --- a/samoa-api/src/main/java/org/apache/samoa/learners/AdaptiveLearner.java +++ b/samoa-api/src/main/java/org/apache/samoa/learners/AdaptiveLearner.java @@ -25,14 +25,13 @@ package org.apache.samoa.learners; */ import org.apache.samoa.moa.classifiers.core.driftdetection.ChangeDetector; -import org.apache.samoa.topology.Stream; /** * The Interface Adaptive Learner. Initializing Classifier should initalize PI to connect the Classifier with the input * stream and initialize result stream so that other PI can connect to the classification result of this classifier */ -public interface AdaptiveLearner { +public interface AdaptiveLearner extends Learner{ /** * Gets the change detector item. http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/fc44099a/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/ensemble/AdaptiveBagging.java ---------------------------------------------------------------------- diff --git a/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/ensemble/AdaptiveBagging.java b/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/ensemble/AdaptiveBagging.java index 9ffba2a..4b2c531 100644 --- a/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/ensemble/AdaptiveBagging.java +++ b/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/ensemble/AdaptiveBagging.java @@ -45,19 +45,16 @@ import com.github.javacliparser.Configurable; import com.github.javacliparser.IntOption; /** - * The Bagging Classifier by Oza and Russell. + * An adaptive version of the Bagging Classifier by Oza and Russell. */ public class AdaptiveBagging implements Learner, Configurable { - /** Logger */ + private static final long serialVersionUID = 8217274236558839040L; private static final Logger logger = LoggerFactory.getLogger(AdaptiveBagging.class); - /** The Constant serialVersionUID. */ - private static final long serialVersionUID = -2971850264864952099L; - /** The base learner option. */ public ClassOption baseLearnerOption = new ClassOption("baseLearner", 'l', - "Classifier to train.", Learner.class, VerticalHoeffdingTree.class.getName()); + "Classifier to train.", AdaptiveLearner.class, VerticalHoeffdingTree.class.getName()); /** The ensemble size option. */ public IntOption ensembleSizeOption = new IntOption("ensembleSize", 's', @@ -69,59 +66,63 @@ public class AdaptiveBagging implements Learner, Configurable { /** The distributor processor. */ private BaggingDistributorProcessor distributorP; + /** The input streams for the ensemble, one per member. */ + private Stream[] ensembleStreams; + /** The result stream. */ protected Stream resultStream; /** The dataset. */ private Instances dataset; - protected Learner classifier; - - protected int parallelism; + protected AdaptiveLearner[] ensemble; /** * Sets the layout. */ protected void setLayout() { - - int sizeEnsemble = this.ensembleSizeOption.getValue(); + int ensembleSize = this.ensembleSizeOption.getValue(); distributorP = new BaggingDistributorProcessor(); - distributorP.setSizeEnsemble(sizeEnsemble); - this.builder.addProcessor(distributorP, 1); + distributorP.setEnsembleSize(ensembleSize); + builder.addProcessor(distributorP, 1); // instantiate classifier - classifier = this.baseLearnerOption.getValue(); - if (classifier instanceof AdaptiveLearner) { - // logger.info("Building an AdaptiveLearner {}", - // classifier.getClass().getName()); - AdaptiveLearner ada = (AdaptiveLearner) classifier; - ada.setChangeDetector((ChangeDetector) this.driftDetectionMethodOption.getValue()); + ensemble = new AdaptiveLearner[ensembleSize]; + for (int i = 0; i < ensembleSize; i++) { + try { + ensemble[i] = (AdaptiveLearner) ClassOption.createObject(baseLearnerOption.getValueAsCLIString(), + baseLearnerOption.getRequiredType()); + } catch (Exception e) { + logger.error("Unable to create members of the ensemble. Please check your CLI parameters"); + e.printStackTrace(); + throw new IllegalArgumentException(e); + } + ensemble[i].setChangeDetector((ChangeDetector) this.driftDetectionMethodOption.getValue()); + ensemble[i].init(builder, this.dataset, 1); // sequential } - classifier.init(builder, this.dataset, sizeEnsemble); PredictionCombinerProcessor predictionCombinerP = new PredictionCombinerProcessor(); - predictionCombinerP.setSizeEnsemble(sizeEnsemble); + predictionCombinerP.setEnsembleSize(ensembleSize); this.builder.addProcessor(predictionCombinerP, 1); // Streams - resultStream = this.builder.createStream(predictionCombinerP); + resultStream = builder.createStream(predictionCombinerP); predictionCombinerP.setOutputStream(resultStream); - for (Stream subResultStream : classifier.getResultStreams()) { - this.builder.connectInputKeyStream(subResultStream, predictionCombinerP); + for (AdaptiveLearner member : ensemble) { + for (Stream subResultStream : member.getResultStreams()) { // a learner can have multiple output streams + this.builder.connectInputKeyStream(subResultStream, predictionCombinerP); // the key is the instance id to combine predictions + } } - /* The training stream. */ - Stream testingStream = this.builder.createStream(distributorP); - this.builder.connectInputKeyStream(testingStream, classifier.getInputProcessor()); - - /* The prediction stream. */ - Stream predictionStream = this.builder.createStream(distributorP); - this.builder.connectInputKeyStream(predictionStream, classifier.getInputProcessor()); + ensembleStreams = new Stream[ensembleSize]; + for (int i = 0; i < ensembleSize; i++) { + ensembleStreams[i] = builder.createStream(distributorP); + builder.connectInputShuffleStream(ensembleStreams[i], ensemble[i].getInputProcessor()); // connect streams one-to-one with ensemble members (the type of connection does not matter) + } - distributorP.setOutputStream(testingStream); - distributorP.setPredictionStream(predictionStream); + distributorP.setOutputStreams(ensembleStreams); } /** The builder. */ @@ -131,7 +132,6 @@ public class AdaptiveBagging implements Learner, Configurable { public void init(TopologyBuilder builder, Instances dataset, int parallelism) { this.builder = builder; this.dataset = dataset; - this.parallelism = parallelism; this.setLayout(); } http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/fc44099a/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/ensemble/Bagging.java ---------------------------------------------------------------------- diff --git a/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/ensemble/Bagging.java b/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/ensemble/Bagging.java index 43bc07c..5d7bbfc 100644 --- a/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/ensemble/Bagging.java +++ b/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/ensemble/Bagging.java @@ -71,8 +71,6 @@ public class Bagging implements Learner, Configurable { protected Learner[] ensemble; - protected int parallelism; - /** * Sets the layout. * @@ -129,7 +127,6 @@ public class Bagging implements Learner, Configurable { public void init(TopologyBuilder builder, Instances dataset, int parallelism) { this.builder = builder; this.dataset = dataset; - this.parallelism = parallelism; this.setLayout(); }
