Repository: incubator-samoa
Updated Branches:
  refs/heads/master a92b303de -> 4471fe4ae


SAMOA-35: Add Sharding ensemble method


Project: http://git-wip-us.apache.org/repos/asf/incubator-samoa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-samoa/commit/4471fe4a
Tree: http://git-wip-us.apache.org/repos/asf/incubator-samoa/tree/4471fe4a
Diff: http://git-wip-us.apache.org/repos/asf/incubator-samoa/diff/4471fe4a

Branch: refs/heads/master
Commit: 4471fe4aedee822fd6948ed34fbaba4936671179
Parents: a92b303
Author: Gianmarco De Francisci Morales <[email protected]>
Authored: Thu Jun 18 16:41:50 2015 +0300
Committer: Gianmarco De Francisci Morales <[email protected]>
Committed: Tue Jul 4 15:29:42 2017 +0300

----------------------------------------------------------------------
 .../learners/classifiers/ensemble/Bagging.java  |   3 +-
 .../learners/classifiers/ensemble/Sharding.java | 142 ++++++++++++++++
 .../ensemble/ShardingDistributorProcessor.java  | 161 +++++++++++++++++++
 3 files changed, 304 insertions(+), 2 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/4471fe4a/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/ensemble/Bagging.java
----------------------------------------------------------------------
diff --git 
a/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/ensemble/Bagging.java
 
b/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/ensemble/Bagging.java
index 7178738..967684f 100644
--- 
a/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/ensemble/Bagging.java
+++ 
b/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/ensemble/Bagging.java
@@ -143,7 +143,6 @@ public class Bagging implements ClassificationLearner, 
Configurable {
    */
   @Override
   public Set<Stream> getResultStreams() {
-    Set<Stream> streams = ImmutableSet.of(this.resultStream);
-    return streams;
+    return ImmutableSet.of(this.resultStream);
   }
 }

http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/4471fe4a/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/ensemble/Sharding.java
----------------------------------------------------------------------
diff --git 
a/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/ensemble/Sharding.java
 
b/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/ensemble/Sharding.java
new file mode 100644
index 0000000..588d9f2
--- /dev/null
+++ 
b/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/ensemble/Sharding.java
@@ -0,0 +1,142 @@
+package org.apache.samoa.learners.classifiers.ensemble;
+
+/*
+ * #%L
+ * SAMOA
+ * %%
+ * Copyright (C) 2014 - 2015 Apache Software Foundation
+ * %%
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ * 
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * #L%
+ */
+
+import java.util.Set;
+
+import org.apache.samoa.core.Processor;
+import org.apache.samoa.instances.Instances;
+import org.apache.samoa.learners.Learner;
+import org.apache.samoa.learners.classifiers.trees.VerticalHoeffdingTree;
+import org.apache.samoa.topology.Stream;
+import org.apache.samoa.topology.TopologyBuilder;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import com.github.javacliparser.ClassOption;
+import com.github.javacliparser.Configurable;
+import com.github.javacliparser.IntOption;
+import com.google.common.collect.ImmutableSet;
+
+/**
+ * Simple sharding meta-classifier. It trains an ensemble of learners by sending each training instance to exactly one
+ * randomly chosen member, so that the members are trained on disjoint parts of the stream and are completely
+ * independent of each other. The result streams of all members are routed to a single
+ * {@link PredictionCombinerProcessor}, whose output stream is the result stream of this learner.
+ */
+public class Sharding implements Learner, Configurable {
+
+  private static final long serialVersionUID = -2971850264864952099L;
+  private static final Logger logger = LoggerFactory.getLogger(Sharding.class);
+
+  /** The base learner class. */
+  public ClassOption baseLearnerOption = new ClassOption("baseLearner", 'l',
+      "Classifier to train.", Learner.class, VerticalHoeffdingTree.class.getName());
+
+  /** The ensemble size option. */
+  public IntOption ensembleSizeOption = new IntOption("ensembleSize", 's',
+      "The number of models in the bag.", 10, 1, Integer.MAX_VALUE);
+
+  /** The distributor processor; it is also the input processor of this learner. */
+  private ShardingDistributorProcessor distributor;
+
+  /** The input streams for the ensemble, one per member. */
+  private Stream[] ensembleStreams;
+
+  /** The result stream, produced by the prediction combiner. */
+  protected Stream resultStream;
+
+  /** The dataset handed to every ensemble member on initialization. */
+  private Instances dataset;
+
+  /** The topology builder used to wire processors and streams. */
+  private TopologyBuilder builder;
+
+  /** The ensemble members, one instance of the base learner each. */
+  protected Learner[] ensemble;
+
+  /**
+   * Sets the layout: one distributor, {@code ensembleSize} base learners and one prediction combiner, connected in
+   * that order. Requires {@link #init(TopologyBuilder, Instances, int)} to have stored the builder and the dataset.
+   *
+   * @throws IllegalArgumentException
+   *           if a base learner cannot be instantiated from the {@code baseLearner} option
+   */
+  protected void setLayout() {
+
+    int ensembleSize = this.ensembleSizeOption.getValue();
+
+    distributor = new ShardingDistributorProcessor();
+    distributor.setEnsembleSize(ensembleSize);
+    this.builder.addProcessor(distributor, 1);
+
+    // instantiate classifier
+    ensemble = new Learner[ensembleSize];
+    for (int i = 0; i < ensembleSize; i++) {
+      try {
+        ensemble[i] = (Learner) ClassOption.createObject(baseLearnerOption.getValueAsCLIString(),
+            baseLearnerOption.getRequiredType());
+      } catch (Exception e) {
+        // pass the exception to the logger so the stack trace reaches the configured log instead of stderr
+        logger.error("Unable to create members of the ensemble. Please check your CLI parameters", e);
+        throw new IllegalArgumentException(e);
+      }
+      ensemble[i].init(builder, this.dataset, 1); // sequential
+    }
+
+    PredictionCombinerProcessor predictionCombiner = new PredictionCombinerProcessor();
+    predictionCombiner.setEnsembleSize(ensembleSize);
+    this.builder.addProcessor(predictionCombiner, 1);
+
+    // Streams
+    resultStream = this.builder.createStream(predictionCombiner);
+    predictionCombiner.setOutputStream(resultStream);
+
+    for (Learner member : ensemble) {
+      // a learner can have multiple output streams
+      for (Stream subResultStream : member.getResultStreams()) {
+        // the key is the instance id to combine predictions
+        this.builder.connectInputKeyStream(subResultStream, predictionCombiner);
+      }
+    }
+
+    ensembleStreams = new Stream[ensembleSize];
+    for (int i = 0; i < ensembleSize; i++) {
+      ensembleStreams[i] = builder.createStream(distributor);
+      // connect streams one-to-one with ensemble members (the type of connection does not matter)
+      builder.connectInputShuffleStream(ensembleStreams[i], ensemble[i].getInputProcessor());
+    }
+
+    distributor.setOutputStreams(ensembleStreams);
+  }
+
+  /**
+   * Stores the builder and the dataset, then builds the ensemble topology.
+   *
+   * @param builder
+   *          the topology builder the processors and streams are added to
+   * @param dataset
+   *          the dataset passed on to every ensemble member
+   * @param parallelism
+   *          not used by this learner; distributor, members and combiner are all created with parallelism 1
+   */
+  @Override
+  public void init(TopologyBuilder builder, Instances dataset, int parallelism) {
+    this.builder = builder;
+    this.dataset = dataset;
+    this.setLayout();
+  }
+
+  /**
+   * @return the distributor processor, which receives the input of the whole ensemble
+   */
+  @Override
+  public Processor getInputProcessor() {
+    return distributor;
+  }
+
+  /**
+   * @return an immutable set holding only the output stream of the prediction combiner
+   */
+  @Override
+  public Set<Stream> getResultStreams() {
+    return ImmutableSet.of(this.resultStream);
+  }
+}

http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/4471fe4a/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/ensemble/ShardingDistributorProcessor.java
----------------------------------------------------------------------
diff --git 
a/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/ensemble/ShardingDistributorProcessor.java
 
b/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/ensemble/ShardingDistributorProcessor.java
new file mode 100644
index 0000000..0e936d7
--- /dev/null
+++ 
b/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/ensemble/ShardingDistributorProcessor.java
@@ -0,0 +1,161 @@
+package org.apache.samoa.learners.classifiers.ensemble;
+
+import java.util.Arrays;
+import java.util.Random;
+
+/*
+ * #%L
+ * SAMOA
+ * %%
+ * Copyright (C) 2014 - 2015 Apache Software Foundation
+ * %%
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ * 
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * #L%
+ */
+
+/**
+ * License
+ */
+
+import org.apache.samoa.core.ContentEvent;
+import org.apache.samoa.core.Processor;
+import org.apache.samoa.instances.Instance;
+import org.apache.samoa.learners.InstanceContentEvent;
+import org.apache.samoa.topology.Stream;
+
+/**
+ * Input processor of the {@link Sharding} ensemble. Every test instance is copied to all ensemble streams, while every
+ * training instance is sent to a single stream whose index is drawn with {@link Random#nextInt(int)}.
+ */
+public class ShardingDistributorProcessor implements Processor {
+
+  private static final long serialVersionUID = -1550901409625192730L;
+
+  /** The ensemble size. */
+  private int ensembleSize;
+
+  /** The stream ensemble, one output stream per ensemble member. */
+  private Stream[] ensembleStreams;
+
+  /** Random number generator used to pick the member that receives a training instance. */
+  protected Random random = new Random(); //TODO make random seed configurable
+
+  /**
+   * Routes an incoming instance event to the ensemble streams. The last event is forwarded unchanged to every stream;
+   * a testing event is copied to every stream; a training event is passed to {@link #train(InstanceContentEvent)}.
+   *
+   * @param event
+   *          the event; it is cast to {@link InstanceContentEvent}
+   * @return always {@code false}
+   */
+  @Override
+  public boolean process(ContentEvent event) {
+    InstanceContentEvent inEvent = (InstanceContentEvent) event;
+    if (inEvent.isLastEvent()) {
+      // end learning
+      for (Stream stream : ensembleStreams) {
+        stream.put(event);
+      }
+      return false;
+    }
+
+    if (inEvent.isTesting()) {
+      Instance testInstance = inEvent.getInstance();
+      for (int i = 0; i < ensembleSize; i++) {
+        // every member gets its own copy of the instance, flagged as testing only
+        Instance instanceCopy = testInstance.copy();
+        InstanceContentEvent instanceContentEvent = new InstanceContentEvent(inEvent.getInstanceIndex(), instanceCopy,
+            false, true);
+        instanceContentEvent.setClassifierIndex(i); //TODO probably not needed anymore
+        instanceContentEvent.setEvaluationIndex(inEvent.getEvaluationIndex()); //TODO probably not needed anymore
+        ensembleStreams[i].put(instanceContentEvent);
+      }
+    }
+
+    // estimate model parameters using the training data
+    if (inEvent.isTraining()) {
+      train(inEvent);
+    }
+    return false;
+  }
+
+  /**
+   * Sends a copy of the training instance, flagged as training only, to one randomly selected ensemble stream.
+   *
+   * @param inEvent
+   *          the incoming training event
+   */
+  protected void train(InstanceContentEvent inEvent) {
+    Instance trainInst = inEvent.getInstance().copy();
+    InstanceContentEvent instanceContentEvent = new InstanceContentEvent(inEvent.getInstanceIndex(), trainInst,
+        true, false);
+    int i = random.nextInt(ensembleSize);
+    instanceContentEvent.setClassifierIndex(i);
+    instanceContentEvent.setEvaluationIndex(inEvent.getEvaluationIndex());
+    ensembleStreams[i].put(instanceContentEvent);
+  }
+
+  /**
+   * Does nothing; this processor needs no setup on creation.
+   *
+   * @param id
+   *          the processor id (ignored)
+   */
+  @Override
+  public void onCreate(int id) {
+    // do nothing
+  }
+
+  /**
+   * Gets the output streams.
+   *
+   * @return the internal array of ensemble streams (not a copy), or {@code null} if none has been set
+   */
+  public Stream[] getOutputStreams() {
+    return ensembleStreams;
+  }
+
+  /**
+   * Sets the output streams.
+   *
+   * @param ensembleStreams
+   *          the streams to distribute to; the array is stored as is, without copying
+   */
+  public void setOutputStreams(Stream[] ensembleStreams) {
+    this.ensembleStreams = ensembleStreams;
+  }
+
+  /**
+   * Gets the size ensemble.
+   *
+   * @return the size ensemble
+   */
+  public int getEnsembleSize() {
+    return ensembleSize;
+  }
+
+  /**
+   * Sets the size ensemble.
+   *
+   * @param ensembleSize
+   *          the new size ensemble
+   */
+  public void setEnsembleSize(int ensembleSize) {
+    this.ensembleSize = ensembleSize;
+  }
+
+  /**
+   * Creates a new processor with the same ensemble size and a copy of the output stream array of the source. The new
+   * processor gets its own freshly created random number generator.
+   *
+   * @param sourceProcessor
+   *          the processor to copy from; it is cast to {@link ShardingDistributorProcessor}
+   * @return the new processor
+   */
+  @Override
+  public Processor newProcessor(Processor sourceProcessor) {
+    ShardingDistributorProcessor newProcessor = new ShardingDistributorProcessor();
+    ShardingDistributorProcessor originProcessor = (ShardingDistributorProcessor) sourceProcessor;
+    Stream[] originStreams = originProcessor.getOutputStreams();
+    if (originStreams != null) {
+      // copy the array so that the two processors do not share it; the Stream objects themselves are shared
+      newProcessor.setOutputStreams(Arrays.copyOf(originStreams, originStreams.length));
+    }
+    newProcessor.setEnsembleSize(originProcessor.getEnsembleSize());
+    return newProcessor;
+  }
+}

Reply via email to