Repository: incubator-samoa Updated Branches: refs/heads/master a92b303de -> 4471fe4ae
SAMOA-35: Add Sharding ensemble method Project: http://git-wip-us.apache.org/repos/asf/incubator-samoa/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-samoa/commit/4471fe4a Tree: http://git-wip-us.apache.org/repos/asf/incubator-samoa/tree/4471fe4a Diff: http://git-wip-us.apache.org/repos/asf/incubator-samoa/diff/4471fe4a Branch: refs/heads/master Commit: 4471fe4aedee822fd6948ed34fbaba4936671179 Parents: a92b303 Author: Gianmarco De Francisci Morales <[email protected]> Authored: Thu Jun 18 16:41:50 2015 +0300 Committer: Gianmarco De Francisci Morales <[email protected]> Committed: Tue Jul 4 15:29:42 2017 +0300 ---------------------------------------------------------------------- .../learners/classifiers/ensemble/Bagging.java | 3 +- .../learners/classifiers/ensemble/Sharding.java | 142 ++++++++++++++++ .../ensemble/ShardingDistributorProcessor.java | 161 +++++++++++++++++++ 3 files changed, 304 insertions(+), 2 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/4471fe4a/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/ensemble/Bagging.java ---------------------------------------------------------------------- diff --git a/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/ensemble/Bagging.java b/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/ensemble/Bagging.java index 7178738..967684f 100644 --- a/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/ensemble/Bagging.java +++ b/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/ensemble/Bagging.java @@ -143,7 +143,6 @@ public class Bagging implements ClassificationLearner, Configurable { */ @Override public Set<Stream> getResultStreams() { - Set<Stream> streams = ImmutableSet.of(this.resultStream); - return streams; + return ImmutableSet.of(this.resultStream); } } 
http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/4471fe4a/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/ensemble/Sharding.java ---------------------------------------------------------------------- diff --git a/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/ensemble/Sharding.java b/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/ensemble/Sharding.java new file mode 100644 index 0000000..588d9f2 --- /dev/null +++ b/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/ensemble/Sharding.java @@ -0,0 +1,142 @@ +package org.apache.samoa.learners.classifiers.ensemble; + +/* + * #%L + * SAMOA + * %% + * Copyright (C) 2014 - 2015 Apache Software Foundation + * %% + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * #L% + */ + +import java.util.Set; + +import org.apache.samoa.core.Processor; +import org.apache.samoa.instances.Instances; +import org.apache.samoa.learners.Learner; +import org.apache.samoa.learners.classifiers.trees.VerticalHoeffdingTree; +import org.apache.samoa.topology.Stream; +import org.apache.samoa.topology.TopologyBuilder; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import com.github.javacliparser.ClassOption; +import com.github.javacliparser.Configurable; +import com.github.javacliparser.IntOption; +import com.google.common.collect.ImmutableSet; + +/** + * Simple sharding meta-classifier. 
It trains an ensemble of learners by shuffling the training stream among them, so + * that each learner is completely independent from each other. + */ +public class Sharding implements Learner, Configurable { + + private static final long serialVersionUID = -2971850264864952099L; + private static final Logger logger = LoggerFactory.getLogger(Sharding.class); + + /** The base learner class. */ + public ClassOption baseLearnerOption = new ClassOption("baseLearner", 'l', + "Classifier to train.", Learner.class, VerticalHoeffdingTree.class.getName()); + + /** The ensemble size option. */ + public IntOption ensembleSizeOption = new IntOption("ensembleSize", 's', + "The number of models in the bag.", 10, 1, Integer.MAX_VALUE); + + /** The distributor processor. */ + private ShardingDistributorProcessor distributor; + + /** The input streams for the ensemble, one per member. */ + private Stream[] ensembleStreams; + + /** The result stream. */ + protected Stream resultStream; + + /** The dataset. */ + private Instances dataset; + + protected Learner[] ensemble; + + /** + * Sets the layout. + */ + protected void setLayout() { + + int ensembleSize = this.ensembleSizeOption.getValue(); + + distributor = new ShardingDistributorProcessor(); + distributor.setEnsembleSize(ensembleSize); + this.builder.addProcessor(distributor, 1); + + // instantiate classifier + ensemble = new Learner[ensembleSize]; + for (int i = 0; i < ensembleSize; i++) { + try { + ensemble[i] = (Learner) ClassOption.createObject(baseLearnerOption.getValueAsCLIString(), + baseLearnerOption.getRequiredType()); + } catch (Exception e) { + logger.error("Unable to create members of the ensemble. 
Please check your CLI parameters"); + e.printStackTrace(); + throw new IllegalArgumentException(e); + } + ensemble[i].init(builder, this.dataset, 1); // sequential + } + + PredictionCombinerProcessor predictionCombiner = new PredictionCombinerProcessor(); + predictionCombiner.setEnsembleSize(ensembleSize); + this.builder.addProcessor(predictionCombiner, 1); + + // Streams + resultStream = this.builder.createStream(predictionCombiner); + predictionCombiner.setOutputStream(resultStream); + + for (Learner member : ensemble) { + for (Stream subResultStream : member.getResultStreams()) { // a learner can have multiple output streams + this.builder.connectInputKeyStream(subResultStream, predictionCombiner); // the key is the instance id to combine predictions + } + } + + ensembleStreams = new Stream[ensembleSize]; + for (int i = 0; i < ensembleSize; i++) { + ensembleStreams[i] = builder.createStream(distributor); + builder.connectInputShuffleStream(ensembleStreams[i], ensemble[i].getInputProcessor()); // connect streams one-to-one with ensemble members (the type of connection does not matter) + } + + distributor.setOutputStreams(ensembleStreams); + } + + /** The builder. 
*/ + private TopologyBuilder builder; + + @Override + public void init(TopologyBuilder builder, Instances dataset, int parallelism) { + this.builder = builder; + this.dataset = dataset; + this.setLayout(); + } + + @Override + public Processor getInputProcessor() { + return distributor; + } + + /* + * (non-Javadoc) + * + * @see samoa.learners.Learner#getResultStreams() + */ + @Override + public Set<Stream> getResultStreams() { + return ImmutableSet.of(this.resultStream); + } +} http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/4471fe4a/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/ensemble/ShardingDistributorProcessor.java ---------------------------------------------------------------------- diff --git a/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/ensemble/ShardingDistributorProcessor.java b/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/ensemble/ShardingDistributorProcessor.java new file mode 100644 index 0000000..0e936d7 --- /dev/null +++ b/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/ensemble/ShardingDistributorProcessor.java @@ -0,0 +1,161 @@ +package org.apache.samoa.learners.classifiers.ensemble; + +import java.util.Arrays; +import java.util.Random; + +/* + * #%L + * SAMOA + * %% + * Copyright (C) 2014 - 2015 Apache Software Foundation + * %% + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * #L% + */ + +/** + * License + */ + +import org.apache.samoa.core.ContentEvent; +import org.apache.samoa.core.Processor; +import org.apache.samoa.instances.Instance; +import org.apache.samoa.learners.InstanceContentEvent; +import org.apache.samoa.topology.Stream; + +/** + * The Class ShardingDistributorProcessor. + */ +public class ShardingDistributorProcessor implements Processor { + + private static final long serialVersionUID = -1550901409625192730L; + + /** The ensemble size. */ + private int ensembleSize; + + /** The stream ensemble. */ + private Stream[] ensembleStreams; + + /** Random number generator. */ + protected Random random = new Random(); //TODO make random seed configurable + + /** + * On event. + * + * @param event + * the event + * @return always false + */ + public boolean process(ContentEvent event) { + InstanceContentEvent inEvent = (InstanceContentEvent) event; + if (inEvent.isLastEvent()) { + // end learning + for (Stream stream : ensembleStreams) + stream.put(event); + return false; + } + + if (inEvent.isTesting()) { + Instance testInstance = inEvent.getInstance(); + for (int i = 0; i < ensembleSize; i++) { + Instance instanceCopy = testInstance.copy(); + InstanceContentEvent instanceContentEvent = new InstanceContentEvent(inEvent.getInstanceIndex(), instanceCopy, + false, true); + instanceContentEvent.setClassifierIndex(i); //TODO probably not needed anymore + instanceContentEvent.setEvaluationIndex(inEvent.getEvaluationIndex()); //TODO probably not needed anymore + ensembleStreams[i].put(instanceContentEvent); + } + } + + // estimate model parameters using the training data + if (inEvent.isTraining()) { + train(inEvent); + } + return false; + } + + /** + * Train. 
+ * + * @param inEvent + * the in event + */ + protected void train(InstanceContentEvent inEvent) { + Instance trainInst = inEvent.getInstance().copy(); + InstanceContentEvent instanceContentEvent = new InstanceContentEvent(inEvent.getInstanceIndex(), trainInst, + true, false); + int i = random.nextInt(ensembleSize); + instanceContentEvent.setClassifierIndex(i); + instanceContentEvent.setEvaluationIndex(inEvent.getEvaluationIndex()); + ensembleStreams[i].put(instanceContentEvent); + } + + /* + * (non-Javadoc) + * + * @see org.apache.s4.core.ProcessingElement#onCreate() + */ + @Override + public void onCreate(int id) { + // do nothing + } + + public Stream[] getOutputStreams() { + return ensembleStreams; + } + + public void setOutputStreams(Stream[] ensembleStreams) { + this.ensembleStreams = ensembleStreams; + } + + /** + * Gets the ensemble size. + * + * @return the ensemble size + */ + public int getEnsembleSize() { + return ensembleSize; + } + + /** + * Sets the ensemble size. + * + * @param ensembleSize + * the new ensemble size + */ + public void setEnsembleSize(int ensembleSize) { + this.ensembleSize = ensembleSize; + } + + /* + * (non-Javadoc) + * + * @see samoa.core.Processor#newProcessor(samoa.core.Processor) + */ + @Override + public Processor newProcessor(Processor sourceProcessor) { + ShardingDistributorProcessor newProcessor = new ShardingDistributorProcessor(); + ShardingDistributorProcessor originProcessor = (ShardingDistributorProcessor) sourceProcessor; + if (originProcessor.getOutputStreams() != null) { + newProcessor.setOutputStreams(Arrays.copyOf(originProcessor.getOutputStreams(), + originProcessor.getOutputStreams().length)); + } + newProcessor.setEnsembleSize(originProcessor.getEnsembleSize()); + /* + * if (originProcessor.getLearningCurve() != null){ + * newProcessor.setLearningCurve((LearningCurve) + * originProcessor.getLearningCurve().copy()); } + */ + return newProcessor; + } +}
