http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/9b178f63/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/NaiveBayes.java ---------------------------------------------------------------------- diff --git a/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/NaiveBayes.java b/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/NaiveBayes.java new file mode 100644 index 0000000..df24cd5 --- /dev/null +++ b/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/NaiveBayes.java @@ -0,0 +1,264 @@ +package org.apache.samoa.learners.classifiers; + +/* + * #%L + * SAMOA + * %% + * Copyright (C) 2014 - 2015 Apache Software Foundation + * %% + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * #L% + */ + +import java.util.HashMap; +import java.util.Map; + +import org.apache.samoa.instances.Instance; +import org.apache.samoa.instances.Instances; +import org.apache.samoa.moa.classifiers.core.attributeclassobservers.GaussianNumericAttributeClassObserver; +import org.apache.samoa.moa.core.GaussianEstimator; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Implementation of a non-distributed Naive Bayes classifier. + * + * At the moment, the implementation models all attributes as numeric attributes. + * + * @author Olivier Van Laere (vanlaere yahoo-inc dot com) + */ +public class NaiveBayes implements LocalLearner { + + /** + * Default smoothing factor. For now fixed to 1E-20. 
+ */ + private static final double ADDITIVE_SMOOTHING_FACTOR = 1e-20; + + /** + * serialVersionUID for serialization + */ + private static final long serialVersionUID = 1325775209672996822L; + + /** + * Instance of a logger for use in this class. + */ + private static final Logger logger = LoggerFactory.getLogger(NaiveBayes.class); + + /** + * The actual model. + */ + protected Map<Integer, GaussianNumericAttributeClassObserver> attributeObservers; + + /** + * Class statistics + */ + protected Map<Integer, Double> classInstances; + + /** + * Class zero-prototypes. + */ + protected Map<Integer, Double> classPrototypes; + + /** + * Retrieve the number of classes currently known to this local model + * + * @return the number of classes currently known to this local model + */ + protected int getNumberOfClasses() { + return this.classInstances.size(); + } + + /** + * Track training instances seen. + */ + protected long instancesSeen = 0L; + + /** + * Explicit no-arg constructor. + */ + public NaiveBayes() { + // Init the model + resetLearning(); + } + + /** + * Create an instance of this LocalLearner implementation. + */ + @Override + public LocalLearner create() { + return new NaiveBayes(); + } + + /** + * Predicts the class memberships for a given instance. If an instance is unclassified, the returned array elements + * will be all zero. + * + * Smoothing is being implemented by the AttributeClassObserver classes. At the moment, the + * GaussianNumericProbabilityAttributeClassObserver needs no smoothing as it processes continuous variables. + * + * Please note that we transform the scores to log space to avoid underflow, and we replace the multiplication with + * addition. + * + * The resulting scores are no longer probabilities, as a mixture of probability densities and probabilities can be + * used in the computation. 
+ * + * @param inst + * the instance to be classified + * @return an array containing the estimated membership scores of the test instance in each class, in log space. + */ + @Override + public double[] getVotesForInstance(Instance inst) { + // Prepare the results array + double[] votes = new double[getNumberOfClasses()]; + // Over all classes + for (int classIndex = 0; classIndex < votes.length; classIndex++) { + // Get the prior for this class + votes[classIndex] = Math.log(getPrior(classIndex)); + // Iterate over the instance attributes + for (int index = 0; index < inst.numAttributes(); index++) { + int attributeID = inst.index(index); + // Skip class attribute + if (attributeID == inst.classIndex()) + continue; + Double value = inst.value(attributeID); + // Get the observer for the given attribute + GaussianNumericAttributeClassObserver obs = attributeObservers.get(attributeID); + // Init the estimator to null by default + GaussianEstimator estimator = null; + if (obs != null && obs.getEstimator(classIndex) != null) { + // Get the estimator + estimator = obs.getEstimator(classIndex); + } + double valueNonZero; + // The null case should be handled by smoothing! + if (estimator != null) { + // Get the score for a NON-ZERO attribute value + valueNonZero = estimator.probabilityDensity(value); + } + // We don't have an estimator + else { + // Assign a very small probability that we do see this value + valueNonZero = ADDITIVE_SMOOTHING_FACTOR; + } + votes[classIndex] += Math.log(valueNonZero); // - Math.log(valueZero); + } + // Check for null in the case of prequential evaluation + if (this.classPrototypes.get(classIndex) != null) { + // Add the prototype for the class, already in log space + votes[classIndex] += Math.log(this.classPrototypes.get(classIndex)); + } + } + return votes; + } + + /** + * Compute the prior for the given classIndex. + * + * Implemented by maximum likelihood at the moment. 
+ * + * @param classIndex + * Id of the class for which we want to compute the prior. + * @return Prior probability for the requested class + */ + private double getPrior(int classIndex) { + // Maximum likelihood + Double currentCount = this.classInstances.get(classIndex); + if (currentCount == null || currentCount == 0) + return 0; + else + return currentCount * 1. / this.instancesSeen; + } + + /** + * Resets this classifier. It must be similar to starting a new classifier from scratch. + */ + @Override + public void resetLearning() { + // Reset priors + this.instancesSeen = 0L; + this.classInstances = new HashMap<>(); + this.classPrototypes = new HashMap<>(); + // Init the attribute observers + this.attributeObservers = new HashMap<>(); + } + + /** + * Trains this classifier incrementally using the given instance. + * + * @param inst + * the instance to be used for training + */ + @Override + public void trainOnInstance(Instance inst) { + // Update class statistics with weights + int classIndex = (int) inst.classValue(); + Double weight = this.classInstances.get(classIndex); + if (weight == null) + weight = 0.; + this.classInstances.put(classIndex, weight + inst.weight()); + + // Get the class prototype + Double classPrototype = this.classPrototypes.get(classIndex); + if (classPrototype == null) + classPrototype = 1.; + + // Iterate over the attributes of the given instance + for (int attributePosition = 0; attributePosition < inst + .numAttributes(); attributePosition++) { + // Get the attribute index - Dense -> 1:1, Sparse is remapped + int attributeID = inst.index(attributePosition); + // Skip class attribute + if (attributeID == inst.classIndex()) + continue; + // Get the attribute observer for the current attribute + GaussianNumericAttributeClassObserver obs = this.attributeObservers + .get(attributeID); + // Lazy init of observers, if null, instantiate a new one + if (obs == null) { + // FIXME: At this point, we model everything as a numeric + // attribute 
+ obs = new GaussianNumericAttributeClassObserver(); + this.attributeObservers.put(attributeID, obs); + } + + // Get the probability density function under the current model + GaussianEstimator obs_estimator = obs.getEstimator(classIndex); + if (obs_estimator != null) { + // Fetch the probability that the feature value is zero + double probDens_zero_current = obs_estimator.probabilityDensity(0); + classPrototype -= probDens_zero_current; + } + + // FIXME: Sanity check on data values, for now just learn + // Learn attribute value for given class + obs.observeAttributeClass(inst.valueSparse(attributePosition), + (int) inst.classValue(), inst.weight()); + + // Update obs_estimator to fetch the pdf from the updated model + obs_estimator = obs.getEstimator(classIndex); + // Fetch the probability that the feature value is zero + double probDens_zero_updated = obs_estimator.probabilityDensity(0); + // Update the class prototype + classPrototype += probDens_zero_updated; + } + // Store the class prototype + this.classPrototypes.put(classIndex, classPrototype); + // Count another training instance + this.instancesSeen++; + } + + @Override + public void setDataset(Instances dataset) { + // Do nothing + } +}
http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/9b178f63/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/SimpleClassifierAdapter.java ---------------------------------------------------------------------- diff --git a/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/SimpleClassifierAdapter.java b/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/SimpleClassifierAdapter.java new file mode 100644 index 0000000..8db8482 --- /dev/null +++ b/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/SimpleClassifierAdapter.java @@ -0,0 +1,153 @@ +package org.apache.samoa.learners.classifiers; + +/* + * #%L + * SAMOA + * %% + * Copyright (C) 2014 - 2015 Apache Software Foundation + * %% + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * #L% + */ +/** + * License + */ +import org.apache.samoa.instances.Instance; +import org.apache.samoa.instances.Instances; +import org.apache.samoa.instances.InstancesHeader; +import org.apache.samoa.moa.classifiers.functions.MajorityClass; + +import com.github.javacliparser.ClassOption; +import com.github.javacliparser.Configurable; + +/** + * + * Base class for adapting external classifiers. 
+ * + */ +public class SimpleClassifierAdapter implements LocalLearner, Configurable { + + /** + * + */ + private static final long serialVersionUID = 4372366401338704353L; + + public ClassOption learnerOption = new ClassOption("learner", 'l', + "Classifier to train.", org.apache.samoa.moa.classifiers.Classifier.class, MajorityClass.class.getName()); + /** + * The learner. + */ + protected org.apache.samoa.moa.classifiers.Classifier learner; + + /** + * The is init. + */ + protected Boolean isInit; + + /** + * The dataset. + */ + protected Instances dataset; + + @Override + public void setDataset(Instances dataset) { + this.dataset = dataset; + } + + /** + * Instantiates a new learner. + * + * @param learner + * the learner + * @param dataset + * the dataset + */ + public SimpleClassifierAdapter(org.apache.samoa.moa.classifiers.Classifier learner, Instances dataset) { + this.learner = learner.copy(); + this.isInit = false; + this.dataset = dataset; + } + + /** + * Instantiates a new learner. + * + */ + public SimpleClassifierAdapter() { + this.learner = ((org.apache.samoa.moa.classifiers.Classifier) this.learnerOption.getValue()).copy(); + this.isInit = false; + } + + /** + * Creates a new learner object. + * + * @return the learner + */ + @Override + public SimpleClassifierAdapter create() { + SimpleClassifierAdapter l = new SimpleClassifierAdapter(learner, dataset); + if (dataset == null) { + System.out.println("dataset null while creating"); + } + return l; + } + + /** + * Trains this classifier incrementally using the given instance. 
+ * + * @param inst + * the instance to be used for training + */ + @Override + public void trainOnInstance(Instance inst) { + if (!this.isInit) { + this.isInit = true; + InstancesHeader instances = new InstancesHeader(dataset); + this.learner.setModelContext(instances); + this.learner.prepareForUse(); + } + if (inst.weight() > 0) { + inst.setDataset(dataset); + learner.trainOnInstance(inst); + } + } + + /** + * Predicts the class memberships for a given instance. If an instance is unclassified, the returned array elements + * must be all zero. + * + * @param inst + * the instance to be classified + * @return an array containing the estimated membership probabilities of the test instance in each class + */ + @Override + public double[] getVotesForInstance(Instance inst) { + double[] ret; + inst.setDataset(dataset); + if (!this.isInit) { + ret = new double[dataset.numClasses()]; + } else { + ret = learner.getVotesForInstance(inst); + } + return ret; + } + + /** + * Resets this classifier. It must be similar to starting a new classifier from scratch. 
+ * + */ + @Override + public void resetLearning() { + learner.resetLearning(); + } + +} http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/9b178f63/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/SingleClassifier.java ---------------------------------------------------------------------- diff --git a/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/SingleClassifier.java b/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/SingleClassifier.java new file mode 100644 index 0000000..5c989f3 --- /dev/null +++ b/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/SingleClassifier.java @@ -0,0 +1,112 @@ +package org.apache.samoa.learners.classifiers; + +/* + * #%L + * SAMOA + * %% + * Copyright (C) 2014 - 2015 Apache Software Foundation + * %% + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * #L% + */ + +/** + * License + */ + +import com.google.common.collect.ImmutableSet; + +import java.util.Set; + +import org.apache.samoa.core.Processor; +import org.apache.samoa.instances.Instances; +import org.apache.samoa.learners.AdaptiveLearner; +import org.apache.samoa.learners.Learner; +import org.apache.samoa.moa.classifiers.core.driftdetection.ChangeDetector; +import org.apache.samoa.topology.Stream; +import org.apache.samoa.topology.TopologyBuilder; + +import com.github.javacliparser.ClassOption; +import com.github.javacliparser.Configurable; + +/** + * + * Classifier that contain a single classifier. 
*
* <p>Builds a topology made of one {@link LocalLearnerProcessor} that hosts the configured {@link LocalLearner}
* and publishes its output on a single result stream.
*/
public final class SingleClassifier implements Learner, AdaptiveLearner, Configurable {

  private static final long serialVersionUID = 684111382631697031L;

  /** Option selecting the local learner hosted by the processor. */
  public ClassOption learnerOption = new ClassOption("learner", 'l',
      "Classifier to train.", LocalLearner.class, SimpleClassifierAdapter.class.getName());

  /** Change detector handed to the learner processor when the layout is built. */
  protected ChangeDetector changeDetector;

  /** Processor hosting the learner; also the input processor of this learner. */
  private LocalLearnerProcessor learnerP;

  /** Output stream of the learner processor. */
  private Stream resultStream;

  private Instances dataset;

  private TopologyBuilder builder;

  private int parallelism;

  /**
   * Stores the arguments and builds the layout.
   */
  @Override
  public void init(TopologyBuilder builder, Instances dataset, int parallelism) {
    this.builder = builder;
    this.dataset = dataset;
    this.parallelism = parallelism;
    this.setLayout();
  }

  /**
   * Creates the learner processor, configures it with the change detector and the learner taken from
   * {@link #learnerOption}, adds it to the topology with the requested parallelism and attaches its output stream.
   */
  protected void setLayout() {
    learnerP = new LocalLearnerProcessor();
    learnerP.setChangeDetector(this.getChangeDetector());

    LocalLearner localLearner = this.learnerOption.getValue();
    localLearner.setDataset(this.dataset);
    learnerP.setLearner(localLearner);

    this.builder.addProcessor(learnerP, parallelism);

    resultStream = this.builder.createStream(learnerP);
    learnerP.setOutputStream(resultStream);
  }

  @Override
  public Processor getInputProcessor() {
    return learnerP;
  }

  /*
   * (non-Javadoc)
   *
   * @see samoa.learners.Learner#getResultStreams()
   */
  @Override
  public Set<Stream> getResultStreams() {
    return ImmutableSet.of(this.resultStream);
  }

  @Override
  public ChangeDetector getChangeDetector() {
    return this.changeDetector;
  }

  @Override
  public void setChangeDetector(ChangeDetector cd) {
    this.changeDetector = cd;
  }
}
http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/9b178f63/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/ensemble/AdaptiveBagging.java ---------------------------------------------------------------------- diff --git 
a/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/ensemble/AdaptiveBagging.java b/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/ensemble/AdaptiveBagging.java new file mode 100644 index 0000000..9ffba2a --- /dev/null +++ b/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/ensemble/AdaptiveBagging.java @@ -0,0 +1,152 @@ +package org.apache.samoa.learners.classifiers.ensemble; + +/* + * #%L + * SAMOA + * %% + * Copyright (C) 2014 - 2015 Apache Software Foundation + * %% + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * #L% + */ + +/** + * License + */ + +import com.google.common.collect.ImmutableSet; + +import java.util.Set; + +import org.apache.samoa.core.Processor; +import org.apache.samoa.instances.Instances; +import org.apache.samoa.learners.AdaptiveLearner; +import org.apache.samoa.learners.Learner; +import org.apache.samoa.learners.classifiers.trees.VerticalHoeffdingTree; +import org.apache.samoa.moa.classifiers.core.driftdetection.ADWINChangeDetector; +import org.apache.samoa.moa.classifiers.core.driftdetection.ChangeDetector; +import org.apache.samoa.topology.Stream; +import org.apache.samoa.topology.TopologyBuilder; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import com.github.javacliparser.ClassOption; +import com.github.javacliparser.Configurable; +import com.github.javacliparser.IntOption; + +/** + * The Bagging Classifier by Oza and Russell. 
*
* <p>In this variant, a base learner that implements {@link AdaptiveLearner} is given the change detector
* configured through {@code driftDetectionMethodOption} before it is initialised.
*/
public class AdaptiveBagging implements Learner, Configurable {

  /** Logger */
  private static final Logger logger = LoggerFactory.getLogger(AdaptiveBagging.class);

  /** The Constant serialVersionUID. */
  private static final long serialVersionUID = -2971850264864952099L;

  /** The base learner option. */
  public ClassOption baseLearnerOption = new ClassOption("baseLearner", 'l',
      "Classifier to train.", Learner.class, VerticalHoeffdingTree.class.getName());

  /** The ensemble size option. */
  public IntOption ensembleSizeOption = new IntOption("ensembleSize", 's',
      "The number of models in the bag.", 10, 1, Integer.MAX_VALUE);

  /** The change detector option, applied to base learners that are adaptive. */
  public ClassOption driftDetectionMethodOption = new ClassOption("driftDetectionMethod", 'd',
      "Drift detection method to use.", ChangeDetector.class, ADWINChangeDetector.class.getName());

  /** The distributor processor. */
  private BaggingDistributorProcessor distributorP;

  /** The result stream. */
  protected Stream resultStream;

  /** The dataset. */
  private Instances dataset;

  /** The base learner, initialised with the ensemble size as its parallelism. */
  protected Learner classifier;

  protected int parallelism;

  /** The builder. */
  private TopologyBuilder builder;

  /**
   * Sets the layout: distributor -> base learner -> prediction combiner. The steps run in a fixed order because
   * each one registers processors or streams with the builder.
   */
  protected void setLayout() {
    int ensembleSize = this.ensembleSizeOption.getValue();

    addDistributor(ensembleSize);
    initBaseLearner(ensembleSize);
    addPredictionCombiner(ensembleSize);

    // The distributor feeds the base learner through two streams: the first carries training events, the second
    // carries prediction events.
    Stream trainingStream = connectDistributorToBaseLearner();
    Stream predictionStream = connectDistributorToBaseLearner();

    distributorP.setOutputStream(trainingStream);
    distributorP.setPredictionStream(predictionStream);
  }

  /** Creates the distributor and adds it to the topology with parallelism 1. */
  private void addDistributor(int ensembleSize) {
    distributorP = new BaggingDistributorProcessor();
    distributorP.setSizeEnsemble(ensembleSize);
    this.builder.addProcessor(distributorP, 1);
  }

  /** Instantiates the base learner, configures drift detection if it is adaptive, and initialises it. */
  private void initBaseLearner(int ensembleSize) {
    classifier = this.baseLearnerOption.getValue();
    if (classifier instanceof AdaptiveLearner) {
      ChangeDetector detector = (ChangeDetector) this.driftDetectionMethodOption.getValue();
      ((AdaptiveLearner) classifier).setChangeDetector(detector);
    }
    classifier.init(builder, this.dataset, ensembleSize);
  }

  /** Creates the combiner, exposes its output as the result stream and feeds it the base learner's results. */
  private void addPredictionCombiner(int ensembleSize) {
    PredictionCombinerProcessor combiner = new PredictionCombinerProcessor();
    combiner.setSizeEnsemble(ensembleSize);
    this.builder.addProcessor(combiner, 1);

    resultStream = this.builder.createStream(combiner);
    combiner.setOutputStream(resultStream);

    for (Stream learnerResult : classifier.getResultStreams()) {
      this.builder.connectInputKeyStream(learnerResult, combiner);
    }
  }

  /** Creates a stream out of the distributor and key-connects it to the base learner's input processor. */
  private Stream connectDistributorToBaseLearner() {
    Stream stream = this.builder.createStream(distributorP);
    this.builder.connectInputKeyStream(stream, classifier.getInputProcessor());
    return stream;
  }

  /**
   * Stores the arguments and builds the layout.
   */
  @Override
  public void init(TopologyBuilder builder, Instances dataset, int parallelism) {
    this.builder = builder;
    this.dataset = dataset;
    this.parallelism = parallelism;
    this.setLayout();
  }

  @Override
  public Processor getInputProcessor() {
    return distributorP;
  }

  /*
   * (non-Javadoc)
   *
   * @see samoa.learners.Learner#getResultStreams()
   */
  @Override
  public Set<Stream> getResultStreams() {
    return ImmutableSet.of(this.resultStream);
  }
}
http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/9b178f63/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/ensemble/Bagging.java ---------------------------------------------------------------------- diff --git a/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/ensemble/Bagging.java b/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/ensemble/Bagging.java new file mode 100644 index 0000000..7355b1a --- /dev/null +++ b/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/ensemble/Bagging.java @@ -0,0 +1,140 @@ +package org.apache.samoa.learners.classifiers.ensemble; + +/* + * #%L + * SAMOA + * %% + * Copyright (C) 2014 - 2015 Apache Software Foundation + * %% + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * #L% + */ + +/** + * License + */ + +import com.google.common.collect.ImmutableSet; + +import java.util.Set; + +import org.apache.samoa.core.Processor; +import org.apache.samoa.instances.Instances; +import org.apache.samoa.learners.Learner; +import org.apache.samoa.learners.classifiers.trees.VerticalHoeffdingTree; +import org.apache.samoa.topology.Stream; +import org.apache.samoa.topology.TopologyBuilder; + +import com.github.javacliparser.ClassOption; +import com.github.javacliparser.Configurable; +import com.github.javacliparser.IntOption; +

/**
 * The Bagging Classifier by Oza and Russell.
 *
 * <p>The topology is distributor -> base learner -> prediction combiner; the base learner is initialised with the
 * ensemble size as its parallelism.
 */
public class Bagging implements Learner, Configurable {

  /** The Constant serialVersionUID. */
  private static final long serialVersionUID = -2971850264864952099L;

  /** The base learner option. */
  public ClassOption baseLearnerOption = new ClassOption("baseLearner", 'l',
      "Classifier to train.", Learner.class, VerticalHoeffdingTree.class.getName());

  /** The ensemble size option. */
  public IntOption ensembleSizeOption = new IntOption("ensembleSize", 's',
      "The number of models in the bag.", 10, 1, Integer.MAX_VALUE);

  /** The distributor processor. */
  private BaggingDistributorProcessor distributorP;

  /** The stream registered as the distributor's output (training) stream, despite the field's name. */
  private Stream testingStream;

  /** The prediction stream. */
  private Stream predictionStream;

  /** The result stream. */
  protected Stream resultStream;

  /** The dataset. */
  private Instances dataset;

  /** The base learner. */
  protected Learner classifier;

  protected int parallelism;

  /** The builder. */
  private TopologyBuilder builder;

  /**
   * Sets the layout. The steps run in a fixed order because each one registers processors or streams with the
   * builder.
   */
  protected void setLayout() {
    int ensembleSize = this.ensembleSizeOption.getValue();

    // Distributor, parallelism 1
    distributorP = new BaggingDistributorProcessor();
    distributorP.setSizeEnsemble(ensembleSize);
    this.builder.addProcessor(distributorP, 1);

    // Base learner, one parallel instance per ensemble member
    classifier = this.baseLearnerOption.getValue();
    classifier.init(builder, this.dataset, ensembleSize);

    addPredictionCombiner(ensembleSize);

    // Two streams from the distributor into the base learner
    testingStream = connectDistributorToBaseLearner();
    predictionStream = connectDistributorToBaseLearner();

    distributorP.setOutputStream(testingStream);
    distributorP.setPredictionStream(predictionStream);
  }

  /** Creates the combiner, exposes its output as the result stream and feeds it the base learner's results. */
  private void addPredictionCombiner(int ensembleSize) {
    PredictionCombinerProcessor combiner = new PredictionCombinerProcessor();
    combiner.setSizeEnsemble(ensembleSize);
    this.builder.addProcessor(combiner, 1);

    resultStream = this.builder.createStream(combiner);
    combiner.setOutputStream(resultStream);

    for (Stream learnerResult : classifier.getResultStreams()) {
      this.builder.connectInputKeyStream(learnerResult, combiner);
    }
  }

  /** Creates a stream out of the distributor and key-connects it to the base learner's input processor. */
  private Stream connectDistributorToBaseLearner() {
    Stream stream = this.builder.createStream(distributorP);
    this.builder.connectInputKeyStream(stream, classifier.getInputProcessor());
    return stream;
  }

  /**
   * Stores the arguments and builds the layout.
   */
  @Override
  public void init(TopologyBuilder builder, Instances dataset, int parallelism) {
    this.builder = builder;
    this.dataset = dataset;
    this.parallelism = parallelism;
    this.setLayout();
  }

  @Override
  public Processor getInputProcessor() {
    return distributorP;
  }

  /*
   * (non-Javadoc)
   *
   * @see samoa.learners.Learner#getResultStreams()
   */
  @Override
  public Set<Stream> getResultStreams() {
    return ImmutableSet.of(this.resultStream);
  }
}
http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/9b178f63/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/ensemble/BaggingDistributorProcessor.java ---------------------------------------------------------------------- diff --git a/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/ensemble/BaggingDistributorProcessor.java b/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/ensemble/BaggingDistributorProcessor.java new file mode 100644 index 0000000..33615db --- /dev/null +++ b/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/ensemble/BaggingDistributorProcessor.java @@ -0,0 +1,208 @@ +package org.apache.samoa.learners.classifiers.ensemble; + +/* + * #%L + * SAMOA + * %% + * Copyright (C) 2014 - 2015 Apache Software Foundation + * %% + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * #L% + */ + +/** + * License + */ + +import java.util.Random; + +import org.apache.samoa.core.ContentEvent; +import org.apache.samoa.core.Processor; +import org.apache.samoa.instances.Instance; +import org.apache.samoa.learners.InstanceContentEvent; +import org.apache.samoa.moa.core.MiscUtils; +import org.apache.samoa.topology.Stream; +

/**
 * Distributor processor of the bagging ensemble.
 *
 * <p>Fans every incoming instance event out to the ensemble members: testing events are copied once per member
 * onto the prediction stream, and training events are copied onto the training stream with a Poisson(1) weight
 * drawn per member.
 */
public class BaggingDistributorProcessor implements Processor {

  /** serialVersionUID for serialization. */
  private static final long serialVersionUID = -1550901409625192730L;

  /** The size ensemble. */
  private int sizeEnsemble;

  /** The training stream. */
  private Stream trainingStream;

  /** The prediction stream. */
  private Stream predictionStream;

  /** The random source used to draw the Poisson weights. */
  protected Random random = new Random();

  /**
   * Handles one incoming event.
   *
   * @param event
   *          the event; must be an {@link InstanceContentEvent}
   * @return always false
   */
  public boolean process(ContentEvent event) {
    InstanceContentEvent incoming = (InstanceContentEvent) event;

    // End learning: an event with a negative instance index is forwarded unchanged on the prediction stream
    if (incoming.getInstanceIndex() < 0) {
      predictionStream.put(event);
      return false;
    }

    if (incoming.isTesting()) {
      Instance original = incoming.getInstance();
      for (int member = 0; member < sizeEnsemble; member++) {
        // Unweighted copy, one per ensemble member
        predictionStream.put(eventFor(incoming, original.copy(), member, false, true));
      }
    }

    /* Estimate model parameters using the training data. */
    if (incoming.isTraining()) {
      train(incoming);
    }
    return false;
  }

  /**
   * Train.
   *
   * <p>For each ensemble member draws k from Poisson(1); when k is positive, sends the member a copy of the
   * instance whose weight is multiplied by k.
   *
   * @param inEvent
   *          the in event
   */
  protected void train(InstanceContentEvent inEvent) {
    Instance original = inEvent.getInstance();
    for (int member = 0; member < sizeEnsemble; member++) {
      int k = MiscUtils.poisson(1.0, this.random);
      if (k <= 0) {
        continue;
      }
      Instance weighted = original.copy();
      weighted.setWeight(original.weight() * k);
      trainingStream.put(eventFor(inEvent, weighted, member, true, false));
    }
  }

  /**
   * Builds the event sent to one ensemble member, carrying over the source event's instance index and evaluation
   * index.
   *
   * @param source
   *          the incoming event
   * @param instance
   *          the instance to send
   * @param classifierIndex
   *          index of the target ensemble member
   * @param training
   *          third constructor flag (presumably "is training" - confirm against InstanceContentEvent)
   * @param testing
   *          fourth constructor flag (presumably "is testing" - confirm against InstanceContentEvent)
   * @return the new event
   */
  private InstanceContentEvent eventFor(InstanceContentEvent source, Instance instance, int classifierIndex,
      boolean training, boolean testing) {
    InstanceContentEvent outgoing = new InstanceContentEvent(source.getInstanceIndex(), instance, training, testing);
    outgoing.setClassifierIndex(classifierIndex);
    outgoing.setEvaluationIndex(source.getEvaluationIndex());
    return outgoing;
  }

  /*
   * (non-Javadoc)
   *
   * @see org.apache.s4.core.ProcessingElement#onCreate()
   */
  @Override
  public void onCreate(int id) {
    // do nothing
  }

  /**
   * Gets the training stream.
   *
   * @return the training stream
   */
  public Stream getTrainingStream() {
    return trainingStream;
  }

  /**
   * Sets the training stream.
   *
   * @param trainingStream
   *          the new training stream
   */
  public void setOutputStream(Stream trainingStream) {
    this.trainingStream = trainingStream;
  }

  /**
   * Gets the prediction stream.
   *
   * @return the prediction stream
   */
  public Stream getPredictionStream() {
    return predictionStream;
  }

  /**
   * Sets the prediction stream.
   *
   * @param predictionStream
   *          the new prediction stream
   */
  public void setPredictionStream(Stream predictionStream) {
    this.predictionStream = predictionStream;
  }

  /**
   * Gets the size ensemble.
   *
   * @return the size ensemble
   */
  public int getSizeEnsemble() {
    return sizeEnsemble;
  }

  /**
   * Sets the size ensemble.
   *
   * @param sizeEnsemble
   *          the new size ensemble
   */
  public void setSizeEnsemble(int sizeEnsemble) {
    this.sizeEnsemble = sizeEnsemble;
  }

  /*
   * (non-Javadoc)
   *
   * @see samoa.core.Processor#newProcessor(samoa.core.Processor)
   */
  @Override
  public Processor newProcessor(Processor sourceProcessor) {
    BaggingDistributorProcessor replica = new BaggingDistributorProcessor();
    BaggingDistributorProcessor origin = (BaggingDistributorProcessor) sourceProcessor;

    // Streams are shared with the origin; only non-null ones are copied over
    if (origin.getPredictionStream() != null) {
      replica.setPredictionStream(origin.getPredictionStream());
    }
    if (origin.getTrainingStream() != null) {
      replica.setOutputStream(origin.getTrainingStream());
    }
    replica.setSizeEnsemble(origin.getSizeEnsemble());
    return replica;
  }

}
http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/9b178f63/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/ensemble/Boosting.java ---------------------------------------------------------------------- diff --git a/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/ensemble/Boosting.java b/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/ensemble/Boosting.java new file mode 100644 index 0000000..14cd98b --- /dev/null +++ b/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/ensemble/Boosting.java @@ -0,0 +1,149 @@ +package org.apache.samoa.learners.classifiers.ensemble; + +/* + * #%L + * SAMOA + * %% + * Copyright (C) 2014 - 2015 Apache Software Foundation + * %% + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * #L%
 */

/**
 * License
 */

import com.google.common.collect.ImmutableSet;

import java.util.Set;

import org.apache.samoa.core.Processor;
import org.apache.samoa.instances.Instances;
import org.apache.samoa.learners.Learner;
import org.apache.samoa.learners.classifiers.SingleClassifier;
import org.apache.samoa.topology.Stream;
import org.apache.samoa.topology.TopologyBuilder;

import com.github.javacliparser.ClassOption;
import com.github.javacliparser.Configurable;
import com.github.javacliparser.IntOption;

/**
 * The Boosting ensemble learner. (The previous comment said "Bagging Classifier by Oza and Russell"; this class
 * wires the boosting variants of the distributor and prediction combiner.)
 * <p>
 * The topology consists of a {@link BoostingDistributorProcessor} (input), the configured base learner, and a
 * {@link BoostingPredictionCombinerProcessor} that both emits the combined result and feeds training instances
 * back to the base learner.
 */
public class Boosting implements Learner, Configurable {

  /** The Constant serialVersionUID. */
  private static final long serialVersionUID = -2971850264864952099L;

  /** The base learner option. */
  public ClassOption baseLearnerOption = new ClassOption("baseLearner", 'l',
      "Classifier to train.", Learner.class, SingleClassifier.class.getName());

  /** The ensemble size option (default 10, minimum 1). */
  public IntOption ensembleSizeOption = new IntOption("ensembleSize", 's',
      "The number of models in the bag.", 10, 1, Integer.MAX_VALUE);

  /** The distributor processor; it is also the input processor of this learner. */
  private BoostingDistributorProcessor distributorP;

  /** The result stream, produced by the prediction combiner. */
  protected Stream resultStream;

  /** The dataset passed to {@link #init}. */
  private Instances dataset;

  /** The base learner obtained from {@link #baseLearnerOption}. */
  protected Learner classifier;

  /** The parallelism hint passed to {@link #init}; stored but not read anywhere in this class. */
  protected int parallelism;

  /**
   * Builds the topology: distributor -> base learner -> prediction combiner, plus the feedback training stream
   * from the prediction combiner to the base learner.
   */
  protected void setLayout() {

    int sizeEnsemble = this.ensembleSizeOption.getValue();

    distributorP = new BoostingDistributorProcessor();
    distributorP.setSizeEnsemble(sizeEnsemble);
    this.builder.addProcessor(distributorP, 1);

    // instantiate classifier; the ensemble size is passed as the base learner's parallelism
    classifier = this.baseLearnerOption.getValue();
    classifier.init(builder, this.dataset, sizeEnsemble);

    BoostingPredictionCombinerProcessor predictionCombinerP = new BoostingPredictionCombinerProcessor();
    predictionCombinerP.setSizeEnsemble(sizeEnsemble);
    this.builder.addProcessor(predictionCombinerP, 1);

    // Streams
    resultStream = this.builder.createStream(predictionCombinerP);
    predictionCombinerP.setOutputStream(resultStream);

    // Every result stream of the base learner feeds the prediction combiner.
    for (Stream subResultStream : classifier.getResultStreams()) {
      this.builder.connectInputKeyStream(subResultStream, predictionCombinerP);
    }

    /* The testing stream. */
    Stream testingStream = this.builder.createStream(distributorP);
    this.builder.connectInputKeyStream(testingStream, classifier.getInputProcessor());

    /* The prediction stream. */
    Stream predictionStream = this.builder.createStream(distributorP);
    this.builder.connectInputKeyStream(predictionStream, classifier.getInputProcessor());

    // NOTE(review): setOutputStream() assigns the distributor's *training* stream, which the boosting
    // distributor never writes to (its train() is a no-op) - confirm that "testingStream" is intentionally unused.
    distributorP.setOutputStream(testingStream);
    distributorP.setPredictionStream(predictionStream);

    // Addition to Bagging: stream to train
    /* The training stream. */
    Stream trainingStream = this.builder.createStream(predictionCombinerP);
    predictionCombinerP.setTrainingStream(trainingStream);
    this.builder.connectInputKeyStream(trainingStream, classifier.getInputProcessor());

  }

  /** The builder. */
  private TopologyBuilder builder;

  /**
   * Stores the builder, dataset and parallelism hint, then builds the topology via {@link #setLayout()}.
   *
   * @param builder
   *          the topology builder the processors and streams are registered with
   * @param dataset
   *          the dataset handed to the base learner
   * @param parallelism
   *          the parallelism hint (only stored)
   */
  @Override
  public void init(TopologyBuilder builder, Instances dataset, int parallelism) {
    this.builder = builder;
    this.dataset = dataset;
    this.parallelism = parallelism;
    this.setLayout();
  }

  /**
   * @return the distributor processor, which is the entry point of this learner
   */
  @Override
  public Processor getInputProcessor() {
    return distributorP;
  }

  /**
   * @return an immutable set containing only the combined result stream
   */
  @Override
  public Set<Stream> getResultStreams() {
    return ImmutableSet.of(this.resultStream);
  }
}
http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/9b178f63/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/ensemble/BoostingDistributorProcessor.java ---------------------------------------------------------------------- diff --git a/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/ensemble/BoostingDistributorProcessor.java b/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/ensemble/BoostingDistributorProcessor.java new file mode 100644 index 0000000..bcfb853 --- /dev/null +++ b/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/ensemble/BoostingDistributorProcessor.java @@ -0,0 +1,35 @@ +package org.apache.samoa.learners.classifiers.ensemble; + +import org.apache.samoa.learners.InstanceContentEvent; + +/* + * #%L + * SAMOA + * %% + * Copyright (C) 2014 - 2015 Apache Software Foundation + * %% + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and + * limitations under the License. + * #L% + */ + +/** + * The Class BoostingDistributorProcessor. + */ +public class BoostingDistributorProcessor extends BaggingDistributorProcessor { + + @Override + protected void train(InstanceContentEvent inEvent) { + // Boosting is trained from the prediction combiner, not from the input + } + +} http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/9b178f63/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/ensemble/BoostingPredictionCombinerProcessor.java ---------------------------------------------------------------------- diff --git a/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/ensemble/BoostingPredictionCombinerProcessor.java b/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/ensemble/BoostingPredictionCombinerProcessor.java new file mode 100644 index 0000000..6cfcfae --- /dev/null +++ b/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/ensemble/BoostingPredictionCombinerProcessor.java @@ -0,0 +1,178 @@ +package org.apache.samoa.learners.classifiers.ensemble; + +/* + * #%L + * SAMOA + * %% + * Copyright (C) 2014 - 2015 Apache Software Foundation + * %% + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * #L% + */ + +/** + * License + */ +import java.util.HashMap; +import java.util.Map; +import java.util.Random; + +import org.apache.samoa.core.ContentEvent; +import org.apache.samoa.instances.Instance; +import org.apache.samoa.learners.InstanceContentEvent; +import org.apache.samoa.learners.ResultContentEvent; +import org.apache.samoa.moa.core.DoubleVector; +import org.apache.samoa.moa.core.Utils; +import org.apache.samoa.topology.Stream; + +/** + * The Class BoostingPredictionCombinerProcessor. + */ +public class BoostingPredictionCombinerProcessor extends PredictionCombinerProcessor { + + private static final long serialVersionUID = -1606045723451191232L; + + // Weigths classifier + protected double[] scms; + + // Weights instance + protected double[] swms; + + /** + * On event. + * + * @param event + * the event + * @return true, if successful + */ + @Override + public boolean process(ContentEvent event) { + + ResultContentEvent inEvent = (ResultContentEvent) event; + double[] prediction = inEvent.getClassVotes(); + int instanceIndex = (int) inEvent.getInstanceIndex(); + + addStatisticsForInstanceReceived(instanceIndex, inEvent.getClassifierIndex(), prediction, 1); + // Boosting + addPredictions(instanceIndex, inEvent, prediction); + + if (inEvent.isLastEvent() || hasAllVotesArrivedInstance(instanceIndex)) { + DoubleVector combinedVote = this.mapVotesforInstanceReceived.get(instanceIndex); + if (combinedVote == null) { + combinedVote = new DoubleVector(); + } + ResultContentEvent outContentEvent = new ResultContentEvent(inEvent.getInstanceIndex(), + inEvent.getInstance(), inEvent.getClassId(), + combinedVote.getArrayCopy(), inEvent.isLastEvent()); + outContentEvent.setEvaluationIndex(inEvent.getEvaluationIndex()); + outputStream.put(outContentEvent); + clearStatisticsInstance(instanceIndex); + // Boosting + computeBoosting(inEvent, instanceIndex); + return true; + } + return false; + + } + + protected Random random; + + protected int trainingWeightSeenByModel; 
+ + @Override + protected double getEnsembleMemberWeight(int i) { + double em = this.swms[i] / (this.scms[i] + this.swms[i]); + if ((em == 0.0) || (em > 0.5)) { + return 0.0; + } + double Bm = em / (1.0 - em); + return Math.log(1.0 / Bm); + } + + @Override + public void reset() { + this.random = new Random(); + this.trainingWeightSeenByModel = 0; + this.scms = new double[this.ensembleSize]; + this.swms = new double[this.ensembleSize]; + } + + private boolean correctlyClassifies(int i, Instance inst, int instanceIndex) { + int predictedClass = (int) mapPredictions.get(instanceIndex).getValue(i); + return predictedClass == (int) inst.classValue(); + } + + protected Map<Integer, DoubleVector> mapPredictions; + + private void addPredictions(int instanceIndex, ResultContentEvent inEvent, double[] prediction) { + if (this.mapPredictions == null) { + this.mapPredictions = new HashMap<>(); + } + DoubleVector predictions = this.mapPredictions.get(instanceIndex); + if (predictions == null) { + predictions = new DoubleVector(); + } + predictions.setValue(inEvent.getClassifierIndex(), Utils.maxIndex(prediction)); + this.mapPredictions.put(instanceIndex, predictions); + } + + private void computeBoosting(ResultContentEvent inEvent, int instanceIndex) { + // Starts code for Boosting + // Send instances to train + double lambda_d = 1.0; + for (int i = 0; i < this.ensembleSize; i++) { + double k = lambda_d; + Instance inst = inEvent.getInstance(); + if (k > 0.0) { + Instance weightedInst = inst.copy(); + weightedInst.setWeight(inst.weight() * k); + // this.ensemble[i].trainOnInstance(weightedInst); + InstanceContentEvent instanceContentEvent = new InstanceContentEvent( + inEvent.getInstanceIndex(), weightedInst, true, false); + instanceContentEvent.setClassifierIndex(i); + instanceContentEvent.setEvaluationIndex(inEvent.getEvaluationIndex()); + trainingStream.put(instanceContentEvent); + } + if (this.correctlyClassifies(i, inst, instanceIndex)) { + this.scms[i] += lambda_d; + 
lambda_d *= this.trainingWeightSeenByModel / (2 * this.scms[i]); + } else { + this.swms[i] += lambda_d; + lambda_d *= this.trainingWeightSeenByModel / (2 * this.swms[i]); + } + } + } + + /** + * Gets the training stream. + * + * @return the training stream + */ + public Stream getTrainingStream() { + return trainingStream; + } + + /** + * Sets the training stream. + * + * @param trainingStream + * the new training stream + */ + public void setTrainingStream(Stream trainingStream) { + this.trainingStream = trainingStream; + } + + /** The training stream. */ + private Stream trainingStream; + +} http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/9b178f63/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/ensemble/PredictionCombinerProcessor.java ---------------------------------------------------------------------- diff --git a/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/ensemble/PredictionCombinerProcessor.java b/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/ensemble/PredictionCombinerProcessor.java new file mode 100644 index 0000000..2e5f335 --- /dev/null +++ b/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/ensemble/PredictionCombinerProcessor.java @@ -0,0 +1,187 @@ +package org.apache.samoa.learners.classifiers.ensemble; + +/* + * #%L + * SAMOA + * %% + * Copyright (C) 2014 - 2015 Apache Software Foundation + * %% + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * #L% + */ + +/** + * License + */ +import java.util.HashMap; +import java.util.Map; + +import org.apache.samoa.core.ContentEvent; +import org.apache.samoa.core.Processor; +import org.apache.samoa.learners.ResultContentEvent; +import org.apache.samoa.moa.core.DoubleVector; +import org.apache.samoa.topology.Stream; + +/** + * The Class PredictionCombinerProcessor. + */ +public class PredictionCombinerProcessor implements Processor { + + private static final long serialVersionUID = -1606045723451191132L; + + /** + * The size ensemble. + */ + protected int ensembleSize; + + /** + * The output stream. + */ + protected Stream outputStream; + + /** + * Sets the output stream. + * + * @param stream + * the new output stream + */ + public void setOutputStream(Stream stream) { + outputStream = stream; + } + + /** + * Gets the output stream. + * + * @return the output stream + */ + public Stream getOutputStream() { + return outputStream; + } + + /** + * Gets the size ensemble. + * + * @return the ensembleSize + */ + public int getSizeEnsemble() { + return ensembleSize; + } + + /** + * Sets the size ensemble. + * + * @param ensembleSize + * the new size ensemble + */ + public void setSizeEnsemble(int ensembleSize) { + this.ensembleSize = ensembleSize; + } + + protected Map<Integer, Integer> mapCountsforInstanceReceived; + + protected Map<Integer, DoubleVector> mapVotesforInstanceReceived; + + /** + * On event. 
+ * + * @param event + * the event + * @return true, if successful + */ + public boolean process(ContentEvent event) { + + ResultContentEvent inEvent = (ResultContentEvent) event; + double[] prediction = inEvent.getClassVotes(); + int instanceIndex = (int) inEvent.getInstanceIndex(); + + addStatisticsForInstanceReceived(instanceIndex, inEvent.getClassifierIndex(), prediction, 1); + + if (inEvent.isLastEvent() || hasAllVotesArrivedInstance(instanceIndex)) { + DoubleVector combinedVote = this.mapVotesforInstanceReceived.get(instanceIndex); + if (combinedVote == null) { + combinedVote = new DoubleVector(new double[inEvent.getInstance().numClasses()]); + } + ResultContentEvent outContentEvent = new ResultContentEvent(inEvent.getInstanceIndex(), + inEvent.getInstance(), inEvent.getClassId(), + combinedVote.getArrayCopy(), inEvent.isLastEvent()); + outContentEvent.setEvaluationIndex(inEvent.getEvaluationIndex()); + outputStream.put(outContentEvent); + clearStatisticsInstance(instanceIndex); + return true; + } + return false; + + } + + @Override + public void onCreate(int id) { + this.reset(); + } + + public void reset() { + } + + /* + * (non-Javadoc) + * + * @see samoa.core.Processor#newProcessor(samoa.core.Processor) + */ + @Override + public Processor newProcessor(Processor sourceProcessor) { + PredictionCombinerProcessor newProcessor = new PredictionCombinerProcessor(); + PredictionCombinerProcessor originProcessor = (PredictionCombinerProcessor) sourceProcessor; + if (originProcessor.getOutputStream() != null) { + newProcessor.setOutputStream(originProcessor.getOutputStream()); + } + newProcessor.setSizeEnsemble(originProcessor.getSizeEnsemble()); + return newProcessor; + } + + protected void addStatisticsForInstanceReceived(int instanceIndex, int classifierIndex, double[] prediction, int add) { + if (this.mapCountsforInstanceReceived == null) { + this.mapCountsforInstanceReceived = new HashMap<>(); + this.mapVotesforInstanceReceived = new HashMap<>(); + } + 
DoubleVector vote = new DoubleVector(prediction); + if (vote.sumOfValues() > 0.0) { + vote.normalize(); + DoubleVector combinedVote = this.mapVotesforInstanceReceived.get(instanceIndex); + if (combinedVote == null) { + combinedVote = new DoubleVector(); + } + vote.scaleValues(getEnsembleMemberWeight(classifierIndex)); + combinedVote.addValues(vote); + + this.mapVotesforInstanceReceived.put(instanceIndex, combinedVote); + } + Integer count = this.mapCountsforInstanceReceived.get(instanceIndex); + if (count == null) { + count = 0; + } + this.mapCountsforInstanceReceived.put(instanceIndex, count + add); + } + + protected boolean hasAllVotesArrivedInstance(int instanceIndex) { + return (this.mapCountsforInstanceReceived.get(instanceIndex) == this.ensembleSize); + } + + protected void clearStatisticsInstance(int instanceIndex) { + this.mapCountsforInstanceReceived.remove(instanceIndex); + this.mapVotesforInstanceReceived.remove(instanceIndex); + } + + protected double getEnsembleMemberWeight(int i) { + return 1.0; + } + +} http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/9b178f63/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/rules/AMRulesRegressor.java ---------------------------------------------------------------------- diff --git a/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/rules/AMRulesRegressor.java b/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/rules/AMRulesRegressor.java new file mode 100644 index 0000000..58d3eb6 --- /dev/null +++ b/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/rules/AMRulesRegressor.java @@ -0,0 +1,177 @@ +package org.apache.samoa.learners.classifiers.rules; + +/* + * #%L + * SAMOA + * %% + * Copyright (C) 2014 - 2015 Apache Software Foundation + * %% + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * #L%
 */

import com.google.common.collect.ImmutableSet;

import java.util.Set;

import org.apache.samoa.core.Processor;
import org.apache.samoa.instances.Instances;
import org.apache.samoa.learners.RegressionLearner;
import org.apache.samoa.learners.classifiers.rules.centralized.AMRulesRegressorProcessor;
import org.apache.samoa.moa.classifiers.rules.core.attributeclassobservers.FIMTDDNumericAttributeClassLimitObserver;
import org.apache.samoa.moa.classifiers.rules.core.voting.ErrorWeightedVote;
import org.apache.samoa.topology.Stream;
import org.apache.samoa.topology.TopologyBuilder;

import com.github.javacliparser.Configurable;
import com.github.javacliparser.ClassOption;
import com.github.javacliparser.FlagOption;
import com.github.javacliparser.FloatOption;
import com.github.javacliparser.IntOption;
import com.github.javacliparser.MultiChoiceOption;

/**
 * AMRules Regressor is the learner for the centralized (single-processor) implementation of the AMRules algorithm
 * for regression rules. It is adapted to SAMOA from the implementation of AMRules in MOA.
 * <p>
 * All options are forwarded to one {@link AMRulesRegressorProcessor}, which is both the input processor and the
 * source of the single result stream.
 *
 * @author Anh Thu Vu
 *
 */

public class AMRulesRegressor implements RegressionLearner, Configurable {

  /**
   * serialVersionUID for serialization.
   */
  private static final long serialVersionUID = 1L;

  // Options (each option's help text below describes its meaning)
  public FloatOption splitConfidenceOption = new FloatOption(
      "splitConfidence",
      'c',
      "Hoeffding Bound Parameter. The allowable error in split decision, values closer to 0 will take longer to decide.",
      0.0000001, 0.0, 1.0);

  public FloatOption tieThresholdOption = new FloatOption("tieThreshold",
      't', "Hoeffding Bound Parameter. Threshold below which a split will be forced to break ties.",
      0.05, 0.0, 1.0);

  public IntOption gracePeriodOption = new IntOption("gracePeriod",
      'g', "Hoeffding Bound Parameter. The number of instances a leaf should observe between split attempts.",
      200, 1, Integer.MAX_VALUE);

  // NOTE(review): the flag is named "DoNotDetectChanges" but its value is passed straight to
  // Builder.changeDetection(...) in init() - confirm the intended polarity.
  public FlagOption DriftDetectionOption = new FlagOption("DoNotDetectChanges", 'H',
      "Drift Detection. Page-Hinkley.");

  public FloatOption pageHinckleyAlphaOption = new FloatOption(
      "pageHinckleyAlpha",
      'a',
      "The alpha value to use in the Page Hinckley change detection tests.",
      0.005, 0.0, 1.0);

  public IntOption pageHinckleyThresholdOption = new IntOption(
      "pageHinckleyThreshold",
      'l',
      "The threshold value (Lambda) to be used in the Page Hinckley change detection tests.",
      35, 0, Integer.MAX_VALUE);

  public FlagOption noAnomalyDetectionOption = new FlagOption("noAnomalyDetection", 'A',
      "Disable anomaly Detection.");

  public FloatOption multivariateAnomalyProbabilityThresholdOption = new FloatOption(
      "multivariateAnomalyProbabilityThresholdd",
      'm',
      "Multivariate anomaly threshold value.",
      0.99, 0.0, 1.0);

  public FloatOption univariateAnomalyProbabilityThresholdOption = new FloatOption(
      "univariateAnomalyprobabilityThreshold",
      'u',
      "Univariate anomaly threshold value.",
      0.10, 0.0, 1.0);

  public IntOption anomalyNumInstThresholdOption = new IntOption(
      "anomalyThreshold",
      'n',
      "The threshold value of anomalies to be used in the anomaly detection.",
      30, 0, Integer.MAX_VALUE); // minimum number of instances needed to detect anomalies (an earlier note says 15; the default here is 30)

  public FlagOption unorderedRulesOption = new FlagOption("setUnorderedRulesOn", 'U',
      "unorderedRules.");

  public ClassOption numericObserverOption = new ClassOption("numericObserver",
      'z', "Numeric observer.",
      FIMTDDNumericAttributeClassLimitObserver.class,
      "FIMTDDNumericAttributeClassLimitObserver");

  // The index of the chosen entry (not its label) is what init() passes to the processor builder.
  public MultiChoiceOption predictionFunctionOption = new MultiChoiceOption(
      "predictionFunctionOption", 'P', "The prediction function to use.", new String[] {
          "Adaptative", "Perceptron", "Target Mean" }, new String[] {
          "Adaptative", "Perceptron", "Target Mean" }, 0);

  public FlagOption constantLearningRatioDecayOption = new FlagOption(
      "learningRatio_Decay_set_constant", 'd',
      "Learning Ratio Decay in Perceptron set to be constant. (The next parameter).");

  public FloatOption learningRatioOption = new FloatOption(
      "learningRatio", 's',
      "Constante Learning Ratio to use for training the Perceptrons in the leaves.", 0.025);

  public ClassOption votingTypeOption = new ClassOption("votingType",
      'V', "Voting Type.",
      ErrorWeightedVote.class,
      "InverseErrorWeightedVote");

  // Processor: the single processor that implements the learner
  private AMRulesRegressorProcessor processor;

  // Stream: the result stream created from the processor in init()
  private Stream resultStream;

  /**
   * Builds the processor from the current option values, registers it with the topology using the given
   * parallelism, and creates its result stream.
   *
   * @param topologyBuilder
   *          the builder the processor and its stream are registered with
   * @param dataset
   *          the dataset passed to the processor builder
   * @param parallelism
   *          the parallelism used when adding the processor
   */
  @Override
  public void init(TopologyBuilder topologyBuilder, Instances dataset, int parallelism) {
    this.processor = new AMRulesRegressorProcessor.Builder(dataset)
        .threshold(pageHinckleyThresholdOption.getValue())
        .alpha(pageHinckleyAlphaOption.getValue())
        .changeDetection(this.DriftDetectionOption.isSet())
        .predictionFunction(predictionFunctionOption.getChosenIndex())
        .constantLearningRatioDecay(constantLearningRatioDecayOption.isSet())
        .learningRatio(learningRatioOption.getValue())
        .splitConfidence(splitConfidenceOption.getValue())
        .tieThreshold(tieThresholdOption.getValue())
        .gracePeriod(gracePeriodOption.getValue())
        .noAnomalyDetection(noAnomalyDetectionOption.isSet())
        .multivariateAnomalyProbabilityThreshold(multivariateAnomalyProbabilityThresholdOption.getValue())
        .univariateAnomalyProbabilityThreshold(univariateAnomalyProbabilityThresholdOption.getValue())
        .anomalyNumberOfInstancesThreshold(anomalyNumInstThresholdOption.getValue())
        .unorderedRules(unorderedRulesOption.isSet())
        .numericObserver((FIMTDDNumericAttributeClassLimitObserver) numericObserverOption.getValue())
        .voteType((ErrorWeightedVote) votingTypeOption.getValue())
        .build();

    topologyBuilder.addProcessor(processor, parallelism);

    this.resultStream = topologyBuilder.createStream(processor);
    this.processor.setResultStream(resultStream);
  }

  /**
   * @return the regressor processor built in {@link #init}
   */
  @Override
  public Processor getInputProcessor() {
    return processor;
  }

  /**
   * @return an immutable set containing only the processor's result stream
   */
  @Override
  public Set<Stream> getResultStreams() {
    return ImmutableSet.of(this.resultStream);
  }
}
http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/9b178f63/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/rules/HorizontalAMRulesRegressor.java ---------------------------------------------------------------------- diff --git a/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/rules/HorizontalAMRulesRegressor.java b/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/rules/HorizontalAMRulesRegressor.java new file mode 100644 index 0000000..822c2be --- /dev/null +++ b/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/rules/HorizontalAMRulesRegressor.java @@ -0,0 +1,240 @@ +package org.apache.samoa.learners.classifiers.rules; + +/* + * #%L + * SAMOA + * %% + * Copyright (C) 2014 - 2015 Apache Software Foundation + * %% + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * #L%
 */

import com.google.common.collect.ImmutableSet;

import java.util.Set;

import org.apache.samoa.core.Processor;
import org.apache.samoa.instances.Instances;
import org.apache.samoa.learners.RegressionLearner;
import org.apache.samoa.learners.classifiers.rules.distributed.AMRDefaultRuleProcessor;
import org.apache.samoa.learners.classifiers.rules.distributed.AMRLearnerProcessor;
import org.apache.samoa.learners.classifiers.rules.distributed.AMRRuleSetProcessor;
import org.apache.samoa.moa.classifiers.rules.core.attributeclassobservers.FIMTDDNumericAttributeClassLimitObserver;
import org.apache.samoa.topology.Stream;
import org.apache.samoa.topology.TopologyBuilder;

import com.github.javacliparser.ClassOption;
import com.github.javacliparser.Configurable;
import com.github.javacliparser.FlagOption;
import com.github.javacliparser.FloatOption;
import com.github.javacliparser.IntOption;
import com.github.javacliparser.MultiChoiceOption;

/**
 * Horizontal AMRules Regressor is a distributed learner for regression rules. It applies both horizontal
 * parallelism (dividing incoming streams) and vertical parallelism to the AMRules algorithm.
 * <p>
 * The topology built in {@link #init} has three kinds of processors: replicated model (rule set) processors that
 * receive the input, one default-rule ("root") processor, and a group of learner processors.
 *
 * @author Anh Thu Vu
 *
 */
public class HorizontalAMRulesRegressor implements RegressionLearner, Configurable {

  /**
   * serialVersionUID for serialization.
   */
  private static final long serialVersionUID = 2785944439173586051L;

  // Options (each option's help text below describes its meaning)
  public FloatOption splitConfidenceOption = new FloatOption(
      "splitConfidence",
      'c',
      "Hoeffding Bound Parameter. The allowable error in split decision, values closer to 0 will take longer to decide.",
      0.0000001, 0.0, 1.0);

  public FloatOption tieThresholdOption = new FloatOption("tieThreshold",
      't', "Hoeffding Bound Parameter. Threshold below which a split will be forced to break ties.",
      0.05, 0.0, 1.0);

  public IntOption gracePeriodOption = new IntOption("gracePeriod",
      'g', "Hoeffding Bound Parameter. The number of instances a leaf should observe between split attempts.",
      200, 1, Integer.MAX_VALUE);

  // NOTE(review): the flag is named "DoNotDetectChanges" but its value is passed straight to
  // Builder.changeDetection(...) in init() - confirm the intended polarity.
  public FlagOption DriftDetectionOption = new FlagOption("DoNotDetectChanges", 'H',
      "Drift Detection. Page-Hinkley.");

  public FloatOption pageHinckleyAlphaOption = new FloatOption(
      "pageHinckleyAlpha",
      'a',
      "The alpha value to use in the Page Hinckley change detection tests.",
      0.005, 0.0, 1.0);

  public IntOption pageHinckleyThresholdOption = new IntOption(
      "pageHinckleyThreshold",
      'l',
      "The threshold value (Lambda) to be used in the Page Hinckley change detection tests.",
      35, 0, Integer.MAX_VALUE);

  public FlagOption noAnomalyDetectionOption = new FlagOption("noAnomalyDetection", 'A',
      "Disable anomaly Detection.");

  public FloatOption multivariateAnomalyProbabilityThresholdOption = new FloatOption(
      "multivariateAnomalyProbabilityThresholdd",
      'm',
      "Multivariate anomaly threshold value.",
      0.99, 0.0, 1.0);

  public FloatOption univariateAnomalyProbabilityThresholdOption = new FloatOption(
      "univariateAnomalyprobabilityThreshold",
      'u',
      "Univariate anomaly threshold value.",
      0.10, 0.0, 1.0);

  public IntOption anomalyNumInstThresholdOption = new IntOption(
      "anomalyThreshold",
      'n',
      "The threshold value of anomalies to be used in the anomaly detection.",
      30, 0, Integer.MAX_VALUE); // minimum number of instances needed to detect anomalies (an earlier note says 15; the default here is 30)

  public FlagOption unorderedRulesOption = new FlagOption("setUnorderedRulesOn", 'U',
      "unorderedRules.");

  public ClassOption numericObserverOption = new ClassOption("numericObserver",
      'z', "Numeric observer.",
      FIMTDDNumericAttributeClassLimitObserver.class,
      "FIMTDDNumericAttributeClassLimitObserver");

  // The index of the chosen entry (not its label) is what init() passes to the processor builder.
  public MultiChoiceOption predictionFunctionOption = new MultiChoiceOption(
      "predictionFunctionOption", 'P', "The prediction function to use.", new String[] {
          "Adaptative", "Perceptron", "Target Mean" }, new String[] {
          "Adaptative", "Perceptron", "Target Mean" }, 0);

  public FlagOption constantLearningRatioDecayOption = new FlagOption(
      "learningRatio_Decay_set_constant", 'd',
      "Learning Ratio Decay in Perceptron set to be constant. (The next parameter).");

  public FloatOption learningRatioOption = new FloatOption(
      "learningRatio", 's',
      "Constante Learning Ratio to use for training the Perceptrons in the leaves.", 0.025);

  // Unlike AMRulesRegressor, the voting type is a choice whose index is passed to the model builder.
  public MultiChoiceOption votingTypeOption = new MultiChoiceOption(
      "votingType", 'V', "Voting Type.", new String[] {
          "InverseErrorWeightedVote", "UniformWeightedVote" }, new String[] {
          "InverseErrorWeightedVote", "UniformWeightedVote" }, 0);

  // Parallelism of the learner processors (the option name is spelled "leanerParallelism").
  public IntOption learnerParallelismOption = new IntOption(
      "leanerParallelism",
      'p',
      "The number of local statistics PI to do distributed computation",
      1, 1, Integer.MAX_VALUE);
  // Parallelism of the model (rule set) processors.
  public IntOption ruleSetParallelismOption = new IntOption(
      "modelParallelism",
      'r',
      "The number of replicated model (rule set) PIs",
      1, 1, Integer.MAX_VALUE);

  // Processor: the model (rule set) processor, which is also the input processor of this learner
  private AMRRuleSetProcessor model;

  // Result stream of the model processors
  private Stream modelResultStream;

  // Result stream of the default-rule (root) processor
  private Stream rootResultStream;

  // private Stream resultStream;

  /**
   * Builds the model, default-rule and learner processors from the current option values, registers them with
   * the topology and connects their streams.
   *
   * @param topologyBuilder
   *          the builder the processors and streams are registered with
   * @param dataset
   *          the dataset passed to each processor builder
   * @param parallelism
   *          not used here; parallelism comes from {@link #ruleSetParallelismOption} and
   *          {@link #learnerParallelismOption}
   */
  @Override
  public void init(TopologyBuilder topologyBuilder, Instances dataset, int parallelism) {

    // Create MODEL PIs
    this.model = new AMRRuleSetProcessor.Builder(dataset)
        .noAnomalyDetection(noAnomalyDetectionOption.isSet())
        .multivariateAnomalyProbabilityThreshold(multivariateAnomalyProbabilityThresholdOption.getValue())
        .univariateAnomalyProbabilityThreshold(univariateAnomalyProbabilityThresholdOption.getValue())
        .anomalyNumberOfInstancesThreshold(anomalyNumInstThresholdOption.getValue())
        .unorderedRules(unorderedRulesOption.isSet())
        .voteType(votingTypeOption.getChosenIndex())
        .build();

    topologyBuilder.addProcessor(model, this.ruleSetParallelismOption.getValue());

    // MODEL PIs streams
    Stream forwardToRootStream = topologyBuilder.createStream(this.model);
    Stream forwardToLearnerStream = topologyBuilder.createStream(this.model);
    this.modelResultStream = topologyBuilder.createStream(this.model);

    this.model.setDefaultRuleStream(forwardToRootStream);
    this.model.setStatisticsStream(forwardToLearnerStream);
    this.model.setResultStream(this.modelResultStream);

    // Create DefaultRule PI
    AMRDefaultRuleProcessor root = new AMRDefaultRuleProcessor.Builder(dataset)
        .threshold(pageHinckleyThresholdOption.getValue())
        .alpha(pageHinckleyAlphaOption.getValue())
        .changeDetection(this.DriftDetectionOption.isSet())
        .predictionFunction(predictionFunctionOption.getChosenIndex())
        .constantLearningRatioDecay(constantLearningRatioDecayOption.isSet())
        .learningRatio(learningRatioOption.getValue())
        .splitConfidence(splitConfidenceOption.getValue())
        .tieThreshold(tieThresholdOption.getValue())
        .gracePeriod(gracePeriodOption.getValue())
        .numericObserver((FIMTDDNumericAttributeClassLimitObserver) numericObserverOption.getValue())
        .build();

    // Added without an explicit parallelism argument.
    topologyBuilder.addProcessor(root);

    // Default Rule PI streams
    Stream newRuleStream = topologyBuilder.createStream(root);
    this.rootResultStream = topologyBuilder.createStream(root);

    root.setRuleStream(newRuleStream);
    root.setResultStream(this.rootResultStream);

    // Create Learner PIs
    AMRLearnerProcessor learner = new AMRLearnerProcessor.Builder(dataset)
        .splitConfidence(splitConfidenceOption.getValue())
        .tieThreshold(tieThresholdOption.getValue())
        .gracePeriod(gracePeriodOption.getValue())
        .noAnomalyDetection(noAnomalyDetectionOption.isSet())
        .multivariateAnomalyProbabilityThreshold(multivariateAnomalyProbabilityThresholdOption.getValue())
        .univariateAnomalyProbabilityThreshold(univariateAnomalyProbabilityThresholdOption.getValue())
        .anomalyNumberOfInstancesThreshold(anomalyNumInstThresholdOption.getValue())
        .build();

    topologyBuilder.addProcessor(learner, this.learnerParallelismOption.getValue());

    Stream predicateStream = topologyBuilder.createStream(learner);
    learner.setOutputStream(predicateStream);

    // Connect streams
    // to MODEL
    topologyBuilder.connectInputAllStream(newRuleStream, this.model);
    topologyBuilder.connectInputAllStream(predicateStream, this.model);
    // to ROOT
    topologyBuilder.connectInputShuffleStream(forwardToRootStream, root);
    // to LEARNER
    topologyBuilder.connectInputKeyStream(forwardToLearnerStream, learner);
    topologyBuilder.connectInputAllStream(newRuleStream, learner);
  }

  /**
   * @return the model (rule set) processor, which receives the input instances
   */
  @Override
  public Processor getInputProcessor() {
    return model;
  }

  /**
   * @return an immutable set with the result streams of the model processors and of the default-rule processor
   */
  @Override
  public Set<Stream> getResultStreams() {
    Set<Stream> streams = ImmutableSet.of(this.modelResultStream, this.rootResultStream);
    return streams;
  }

}
http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/9b178f63/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/rules/VerticalAMRulesRegressor.java ---------------------------------------------------------------------- diff --git a/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/rules/VerticalAMRulesRegressor.java b/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/rules/VerticalAMRulesRegressor.java new file mode 100644 index 0000000..2fb5c2d --- /dev/null +++ b/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/rules/VerticalAMRulesRegressor.java @@ -0,0 +1,200 @@ +package
org.apache.samoa.learners.classifiers.rules; + +/* + * #%L + * SAMOA + * %% + * Copyright (C) 2014 - 2015 Apache Software Foundation + * %% + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * #L% + */ +import java.util.Set; + +import org.apache.samoa.core.Processor; +import org.apache.samoa.instances.Instances; +import org.apache.samoa.learners.RegressionLearner; +import org.apache.samoa.learners.classifiers.rules.distributed.AMRulesAggregatorProcessor; +import org.apache.samoa.learners.classifiers.rules.distributed.AMRulesStatisticsProcessor; +import org.apache.samoa.moa.classifiers.rules.core.attributeclassobservers.FIMTDDNumericAttributeClassLimitObserver; +import org.apache.samoa.topology.Stream; +import org.apache.samoa.topology.TopologyBuilder; + +import com.github.javacliparser.ClassOption; +import com.github.javacliparser.Configurable; +import com.github.javacliparser.FlagOption; +import com.github.javacliparser.FloatOption; +import com.github.javacliparser.IntOption; +import com.github.javacliparser.MultiChoiceOption; +import com.google.common.collect.ImmutableSet; + +/** + * Vertical AMRules Regressor is a distributed learner for regression rules learner. It applies vertical parallelism on + * AMRules regressor. 
+ * + * @author Anh Thu Vu + * + */ + +public class VerticalAMRulesRegressor implements RegressionLearner, Configurable { + + /** + * + */ + private static final long serialVersionUID = 2785944439173586051L; + + // Options + public FloatOption splitConfidenceOption = new FloatOption( + "splitConfidence", + 'c', + "Hoeffding Bound Parameter. The allowable error in split decision, values closer to 0 will take longer to decide.", + 0.0000001, 0.0, 1.0); + + public FloatOption tieThresholdOption = new FloatOption("tieThreshold", + 't', "Hoeffding Bound Parameter. Threshold below which a split will be forced to break ties.", + 0.05, 0.0, 1.0); + + public IntOption gracePeriodOption = new IntOption("gracePeriod", + 'g', "Hoeffding Bound Parameter. The number of instances a leaf should observe between split attempts.", + 200, 1, Integer.MAX_VALUE); + + public FlagOption DriftDetectionOption = new FlagOption("DoNotDetectChanges", 'H', + "Drift Detection. Page-Hinkley."); + + public FloatOption pageHinckleyAlphaOption = new FloatOption( + "pageHinckleyAlpha", + 'a', + "The alpha value to use in the Page Hinckley change detection tests.", + 00.005, 0.0, 1.0); + + public IntOption pageHinckleyThresholdOption = new IntOption( + "pageHinckleyThreshold", + 'l', + "The threshold value (Lambda) to be used in the Page Hinckley change detection tests.", + 35, 0, Integer.MAX_VALUE); + + public FlagOption noAnomalyDetectionOption = new FlagOption("noAnomalyDetection", 'A', + "Disable anomaly Detection."); + + public FloatOption multivariateAnomalyProbabilityThresholdOption = new FloatOption( + "multivariateAnomalyProbabilityThresholdd", + 'm', + "Multivariate anomaly threshold value.", + 0.99, 0.0, 1.0); + + public FloatOption univariateAnomalyProbabilityThresholdOption = new FloatOption( + "univariateAnomalyprobabilityThreshold", + 'u', + "Univariate anomaly threshold value.", + 0.10, 0.0, 1.0); + + public IntOption anomalyNumInstThresholdOption = new IntOption( + "anomalyThreshold", 
+ 'n', + "The threshold value of anomalies to be used in the anomaly detection.", + 30, 0, Integer.MAX_VALUE); // num minimum of instances to detect anomalies. 15. + + public FlagOption unorderedRulesOption = new FlagOption("setUnorderedRulesOn", 'U', + "unorderedRules."); + + public ClassOption numericObserverOption = new ClassOption("numericObserver", + 'z', "Numeric observer.", + FIMTDDNumericAttributeClassLimitObserver.class, + "FIMTDDNumericAttributeClassLimitObserver"); + + public MultiChoiceOption predictionFunctionOption = new MultiChoiceOption( + "predictionFunctionOption", 'P', "The prediction function to use.", new String[] { + "Adaptative", "Perceptron", "Target Mean" }, new String[] { + "Adaptative", "Perceptron", "Target Mean" }, 0); + + public FlagOption constantLearningRatioDecayOption = new FlagOption( + "learningRatio_Decay_set_constant", 'd', + "Learning Ratio Decay in Perceptron set to be constant. (The next parameter)."); + + public FloatOption learningRatioOption = new FloatOption( + "learningRatio", 's', + "Constante Learning Ratio to use for training the Perceptrons in the leaves.", 0.025); + + public MultiChoiceOption votingTypeOption = new MultiChoiceOption( + "votingType", 'V', "Voting Type.", new String[] { + "InverseErrorWeightedVote", "UniformWeightedVote" }, new String[] { + "InverseErrorWeightedVote", "UniformWeightedVote" }, 0); + + public IntOption parallelismHintOption = new IntOption( + "parallelismHint", + 'p', + "The number of local statistics PI to do distributed computation", + 1, 1, Integer.MAX_VALUE); + + // Processor + private AMRulesAggregatorProcessor aggregator; + + // Stream + private Stream resultStream; + + @Override + public void init(TopologyBuilder topologyBuilder, Instances dataset, int parallelism) { + + this.aggregator = new AMRulesAggregatorProcessor.Builder(dataset) + .threshold(pageHinckleyThresholdOption.getValue()) + .alpha(pageHinckleyAlphaOption.getValue()) + 
.changeDetection(this.DriftDetectionOption.isSet()) + .predictionFunction(predictionFunctionOption.getChosenIndex()) + .constantLearningRatioDecay(constantLearningRatioDecayOption.isSet()) + .learningRatio(learningRatioOption.getValue()) + .splitConfidence(splitConfidenceOption.getValue()) + .tieThreshold(tieThresholdOption.getValue()) + .gracePeriod(gracePeriodOption.getValue()) + .noAnomalyDetection(noAnomalyDetectionOption.isSet()) + .multivariateAnomalyProbabilityThreshold(multivariateAnomalyProbabilityThresholdOption.getValue()) + .univariateAnomalyProbabilityThreshold(univariateAnomalyProbabilityThresholdOption.getValue()) + .anomalyNumberOfInstancesThreshold(anomalyNumInstThresholdOption.getValue()) + .unorderedRules(unorderedRulesOption.isSet()) + .numericObserver((FIMTDDNumericAttributeClassLimitObserver) numericObserverOption.getValue()) + .voteType(votingTypeOption.getChosenIndex()) + .build(); + + topologyBuilder.addProcessor(aggregator); + + Stream statisticsStream = topologyBuilder.createStream(aggregator); + this.resultStream = topologyBuilder.createStream(aggregator); + + this.aggregator.setResultStream(resultStream); + this.aggregator.setStatisticsStream(statisticsStream); + + AMRulesStatisticsProcessor learner = new AMRulesStatisticsProcessor.Builder(dataset) + .splitConfidence(splitConfidenceOption.getValue()) + .tieThreshold(tieThresholdOption.getValue()) + .gracePeriod(gracePeriodOption.getValue()) + .build(); + + topologyBuilder.addProcessor(learner, this.parallelismHintOption.getValue()); + + topologyBuilder.connectInputKeyStream(statisticsStream, learner); + + Stream predicateStream = topologyBuilder.createStream(learner); + learner.setOutputStream(predicateStream); + + topologyBuilder.connectInputShuffleStream(predicateStream, aggregator); + } + + @Override + public Processor getInputProcessor() { + return aggregator; + } + + @Override + public Set<Stream> getResultStreams() { + return ImmutableSet.of(this.resultStream); + } +}
