package com.yahoo.labs.samoa.learners.classifiers.ensemble;

/*
 * #%L
 * SAMOA
 * %%
 * Copyright (C) 2013 Yahoo! Inc.
 * %%
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * #L%
 */

import java.util.HashMap;
import java.util.Map;
import java.util.Random;

import com.yahoo.labs.samoa.core.ContentEvent;
import com.yahoo.labs.samoa.core.Processor;
import com.yahoo.labs.samoa.instances.Instance;
import com.yahoo.labs.samoa.learners.InstanceContentEvent;
import com.yahoo.labs.samoa.learners.ResultContentEvent;
import com.yahoo.labs.samoa.moa.core.DoubleVector;
import com.yahoo.labs.samoa.moa.core.Utils;
import com.yahoo.labs.samoa.topology.Stream;

/**
 * Prediction combiner for an online-boosting ensemble (OzaBoost-style).
 * <p>
 * Collects the per-member votes for each instance, emits the weighted combined
 * vote on the output stream, and then re-emits the instance on the training
 * stream with a boosting weight (lambda) per ensemble member, updating the
 * correct/wrong weight statistics ({@code scms}/{@code swms}) that drive each
 * member's vote weight.
 */
public class BoostingPredictionCombinerProcessor extends PredictionCombinerProcessor {

    private static final long serialVersionUID = -1606045723451191232L;

    /** Per-member sum of boosting weight classified correctly (scm in OzaBoost). */
    protected double[] scms;

    /** Per-member sum of boosting weight classified wrongly (swm in OzaBoost). */
    protected double[] swms;

    protected Random random;

    /**
     * Number of instances used for boosting so far. Drives the lambda update in
     * {@link #computeBoosting}.
     */
    protected int trainingWeightSeenByModel;

    /**
     * Per-instance predicted class of each ensemble member
     * (classifier index -&gt; predicted class index), kept only until the
     * boosting update for that instance has been computed.
     */
    protected Map<Integer, DoubleVector> mapPredictions;

    /** The training stream (weighted instances sent back to the members). */
    private Stream trainingStream;

    /**
     * Processes one member's vote; when all votes for an instance have arrived
     * (or this is the last event), emits the combined vote and runs the
     * boosting update.
     *
     * @param event the incoming {@link ResultContentEvent}
     * @return true if the combined vote was emitted for this instance
     */
    @Override
    public boolean process(ContentEvent event) {

        ResultContentEvent inEvent = (ResultContentEvent) event;
        double[] prediction = inEvent.getClassVotes();
        int instanceIndex = (int) inEvent.getInstanceIndex();

        addStatisticsForInstanceReceived(instanceIndex, inEvent.getClassifierIndex(), prediction, 1);
        // Boosting: remember this member's predicted class for the lambda update.
        addPredictions(instanceIndex, inEvent, prediction);

        if (inEvent.isLastEvent() || hasAllVotesArrivedInstance(instanceIndex)) {
            DoubleVector combinedVote = this.mapVotesforInstanceReceived.get(instanceIndex);
            if (combinedVote == null) {
                // FIX: size the empty fallback vote by the number of classes
                // (was new DoubleVector(), which yields a zero-length array),
                // consistent with PredictionCombinerProcessor.process().
                combinedVote = new DoubleVector(new double[inEvent.getInstance().numClasses()]);
            }
            ResultContentEvent outContentEvent = new ResultContentEvent(inEvent.getInstanceIndex(),
                    inEvent.getInstance(), inEvent.getClassId(),
                    combinedVote.getArrayCopy(), inEvent.isLastEvent());
            outContentEvent.setEvaluationIndex(inEvent.getEvaluationIndex());
            outputStream.put(outContentEvent);
            clearStatisticsInstance(instanceIndex);
            // Boosting: update member statistics and send weighted training instances.
            computeBoosting(inEvent, instanceIndex);
            return true;
        }
        return false;
    }

    /**
     * AdaBoost-style vote weight ln(1/Bm) for member {@code i}, where
     * em = swm/(scm+swm) and Bm = em/(1-em). Members with no error or with
     * error above 0.5 get weight 0.
     */
    @Override
    protected double getEnsembleMemberWeight(int i) {
        double total = this.scms[i] + this.swms[i];
        if (total == 0.0) {
            // FIX: before any boosting update the denominator is 0 and the
            // original computed NaN; an untrained member contributes no vote.
            return 0.0;
        }
        double em = this.swms[i] / total;
        if ((em == 0.0) || (em > 0.5)) {
            return 0.0;
        }
        double bm = em / (1.0 - em);
        return Math.log(1.0 / bm);
    }

    @Override
    public void reset() {
        this.random = new Random();
        this.trainingWeightSeenByModel = 0;
        this.scms = new double[this.ensembleSize];
        this.swms = new double[this.ensembleSize];
    }

    /**
     * Creates a copy of this processor.
     * <p>
     * FIX: without this override the inherited implementation produced a plain
     * {@code PredictionCombinerProcessor}, silently dropping the boosting
     * behavior (and the training stream) when the topology replicates the
     * processor.
     */
    @Override
    public Processor newProcessor(Processor sourceProcessor) {
        BoostingPredictionCombinerProcessor newProcessor = new BoostingPredictionCombinerProcessor();
        BoostingPredictionCombinerProcessor originProcessor =
                (BoostingPredictionCombinerProcessor) sourceProcessor;
        if (originProcessor.getOutputStream() != null) {
            newProcessor.setOutputStream(originProcessor.getOutputStream());
        }
        if (originProcessor.getTrainingStream() != null) {
            newProcessor.setTrainingStream(originProcessor.getTrainingStream());
        }
        newProcessor.setSizeEnsemble(originProcessor.getSizeEnsemble());
        return newProcessor;
    }

    /** Whether member {@code i} predicted the true class of {@code inst}. */
    private boolean correctlyClassifies(int i, Instance inst, int instanceIndex) {
        int predictedClass = (int) mapPredictions.get(instanceIndex).getValue(i);
        return predictedClass == (int) inst.classValue();
    }

    /** Records member {@code classifierIndex}'s predicted class for this instance. */
    private void addPredictions(int instanceIndex, ResultContentEvent inEvent, double[] prediction) {
        if (this.mapPredictions == null) {
            this.mapPredictions = new HashMap<>();
        }
        DoubleVector predictions = this.mapPredictions.get(instanceIndex);
        if (predictions == null) {
            predictions = new DoubleVector();
        }
        predictions.setValue(inEvent.getClassifierIndex(), Utils.maxIndex(prediction));
        this.mapPredictions.put(instanceIndex, predictions);
    }

    /**
     * OzaBoost update: walks the ensemble, sending the instance with weight
     * lambda_d to each member's training stream and updating scm/swm, scaling
     * lambda_d up on mistakes and down on correct classifications.
     */
    private void computeBoosting(ResultContentEvent inEvent, int instanceIndex) {
        // FIX: this counter was never updated, so the lambda_d multiplier
        // trainingWeightSeenByModel / (2 * scms[i]) was always 0 and only
        // ensemble member 0 ever received training instances.
        this.trainingWeightSeenByModel++;
        double lambda_d = 1.0;
        for (int i = 0; i < this.ensembleSize; i++) {
            double k = lambda_d;
            Instance inst = inEvent.getInstance();
            if (k > 0.0) {
                Instance weightedInst = inst.copy();
                weightedInst.setWeight(inst.weight() * k);
                InstanceContentEvent instanceContentEvent = new InstanceContentEvent(
                        inEvent.getInstanceIndex(), weightedInst, true, false);
                instanceContentEvent.setClassifierIndex(i);
                instanceContentEvent.setEvaluationIndex(inEvent.getEvaluationIndex());
                trainingStream.put(instanceContentEvent);
            }
            if (this.correctlyClassifies(i, inst, instanceIndex)) {
                this.scms[i] += lambda_d;
                lambda_d *= this.trainingWeightSeenByModel / (2 * this.scms[i]);
            } else {
                this.swms[i] += lambda_d;
                lambda_d *= this.trainingWeightSeenByModel / (2 * this.swms[i]);
            }
        }
        // FIX: drop the per-instance predictions once consumed; the original
        // never removed them, so mapPredictions grew without bound (the other
        // per-instance maps are cleared in clearStatisticsInstance).
        this.mapPredictions.remove(instanceIndex);
    }

    /**
     * Gets the training stream.
     *
     * @return the training stream
     */
    public Stream getTrainingStream() {
        return trainingStream;
    }

    /**
     * Sets the training stream.
     *
     * @param trainingStream the new training stream
     */
    public void setTrainingStream(Stream trainingStream) {
        this.trainingStream = trainingStream;
    }
}
package com.yahoo.labs.samoa.learners.classifiers.ensemble;

/*
 * #%L
 * SAMOA
 * %%
 * Copyright (C) 2013 Yahoo! Inc.
 * %%
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * #L%
 */

import java.util.HashMap;
import java.util.Map;

import com.yahoo.labs.samoa.core.ContentEvent;
import com.yahoo.labs.samoa.core.Processor;
import com.yahoo.labs.samoa.learners.ResultContentEvent;
import com.yahoo.labs.samoa.moa.core.DoubleVector;
import com.yahoo.labs.samoa.topology.Stream;

/**
 * Combines the votes of the members of an ensemble for each instance.
 * <p>
 * Votes arrive one {@link ResultContentEvent} per member; once every member
 * has voted for an instance (or the last event is seen), the accumulated,
 * normalized and weighted vote is emitted on the output stream and the
 * per-instance bookkeeping is discarded.
 */
public class PredictionCombinerProcessor implements Processor {

    private static final long serialVersionUID = -1606045723451191132L;

    /** Number of members in the ensemble. */
    protected int ensembleSize;

    /** Stream on which combined votes are emitted. */
    protected Stream outputStream;

    /** Per-instance count of member votes received so far. */
    protected Map<Integer, Integer> mapCountsforInstanceReceived;

    /** Per-instance accumulated (weighted, normalized) vote. */
    protected Map<Integer, DoubleVector> mapVotesforInstanceReceived;

    /**
     * Sets the output stream.
     *
     * @param stream the new output stream
     */
    public void setOutputStream(Stream stream) {
        outputStream = stream;
    }

    /**
     * Gets the output stream.
     *
     * @return the output stream
     */
    public Stream getOutputStream() {
        return outputStream;
    }

    /**
     * Gets the ensemble size.
     *
     * @return the ensembleSize
     */
    public int getSizeEnsemble() {
        return ensembleSize;
    }

    /**
     * Sets the ensemble size.
     *
     * @param ensembleSize the new ensemble size
     */
    public void setSizeEnsemble(int ensembleSize) {
        this.ensembleSize = ensembleSize;
    }

    /**
     * Handles one member vote; emits the combined vote when the instance is
     * complete.
     *
     * @param event the event
     * @return true, if the combined vote was emitted
     */
    public boolean process(ContentEvent event) {

        ResultContentEvent memberVoteEvent = (ResultContentEvent) event;
        int index = (int) memberVoteEvent.getInstanceIndex();

        addStatisticsForInstanceReceived(index, memberVoteEvent.getClassifierIndex(),
                memberVoteEvent.getClassVotes(), 1);

        // Keep waiting unless this was the last event or the final member vote.
        if (!memberVoteEvent.isLastEvent() && !hasAllVotesArrivedInstance(index)) {
            return false;
        }

        DoubleVector accumulated = this.mapVotesforInstanceReceived.get(index);
        if (accumulated == null) {
            // No member produced a usable vote: emit an all-zero vote.
            accumulated = new DoubleVector(new double[memberVoteEvent.getInstance().numClasses()]);
        }
        ResultContentEvent combinedEvent = new ResultContentEvent(
                memberVoteEvent.getInstanceIndex(), memberVoteEvent.getInstance(),
                memberVoteEvent.getClassId(), accumulated.getArrayCopy(),
                memberVoteEvent.isLastEvent());
        combinedEvent.setEvaluationIndex(memberVoteEvent.getEvaluationIndex());
        outputStream.put(combinedEvent);
        clearStatisticsInstance(index);
        return true;
    }

    @Override
    public void onCreate(int id) {
        this.reset();
    }

    /** Resets internal state; subclasses override to reinitialize statistics. */
    public void reset() {
    }

    /* (non-Javadoc)
     * @see samoa.core.Processor#newProcessor(samoa.core.Processor)
     */
    @Override
    public Processor newProcessor(Processor sourceProcessor) {
        PredictionCombinerProcessor original = (PredictionCombinerProcessor) sourceProcessor;
        PredictionCombinerProcessor copy = new PredictionCombinerProcessor();
        Stream originalOutput = original.getOutputStream();
        if (originalOutput != null) {
            copy.setOutputStream(originalOutput);
        }
        copy.setSizeEnsemble(original.getSizeEnsemble());
        return copy;
    }

    /**
     * Folds one member's prediction into the per-instance statistics:
     * normalizes and weights a non-empty vote, then bumps the vote count.
     */
    protected void addStatisticsForInstanceReceived(int instanceIndex, int classifierIndex,
            double[] prediction, int add) {
        if (this.mapCountsforInstanceReceived == null) {
            this.mapCountsforInstanceReceived = new HashMap<>();
            this.mapVotesforInstanceReceived = new HashMap<>();
        }
        DoubleVector memberVote = new DoubleVector(prediction);
        if (memberVote.sumOfValues() > 0.0) {
            memberVote.normalize();
            memberVote.scaleValues(getEnsembleMemberWeight(classifierIndex));
            DoubleVector accumulated = this.mapVotesforInstanceReceived.get(instanceIndex);
            if (accumulated == null) {
                accumulated = new DoubleVector();
            }
            accumulated.addValues(memberVote);
            this.mapVotesforInstanceReceived.put(instanceIndex, accumulated);
        }
        Integer seen = this.mapCountsforInstanceReceived.get(instanceIndex);
        this.mapCountsforInstanceReceived.put(instanceIndex,
                (seen == null ? 0 : seen) + add);
    }

    /** Whether every ensemble member has voted for this instance. */
    protected boolean hasAllVotesArrivedInstance(int instanceIndex) {
        return (this.mapCountsforInstanceReceived.get(instanceIndex) == this.ensembleSize);
    }

    /** Drops the per-instance bookkeeping once the combined vote is emitted. */
    protected void clearStatisticsInstance(int instanceIndex) {
        this.mapCountsforInstanceReceived.remove(instanceIndex);
        this.mapVotesforInstanceReceived.remove(instanceIndex);
    }

    /** Weight of member {@code i}'s vote; uniform (1.0) in the base combiner. */
    protected double getEnsembleMemberWeight(int i) {
        return 1.0;
    }
}
package com.yahoo.labs.samoa.learners.classifiers.rules;

/*
 * #%L
 * SAMOA
 * %%
 * Copyright (C) 2013 - 2014 Yahoo! Inc.
 * %%
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * #L%
 */

import com.google.common.collect.ImmutableSet;
import java.util.Set;

import com.github.javacliparser.Configurable;
import com.github.javacliparser.ClassOption;
import com.github.javacliparser.FlagOption;
import com.github.javacliparser.FloatOption;
import com.github.javacliparser.IntOption;
import com.github.javacliparser.MultiChoiceOption;
import com.yahoo.labs.samoa.core.Processor;
import com.yahoo.labs.samoa.instances.Instances;
import com.yahoo.labs.samoa.learners.RegressionLearner;
import com.yahoo.labs.samoa.learners.classifiers.rules.centralized.AMRulesRegressorProcessor;
import com.yahoo.labs.samoa.moa.classifiers.rules.core.attributeclassobservers.FIMTDDNumericAttributeClassLimitObserver;
import com.yahoo.labs.samoa.moa.classifiers.rules.core.voting.ErrorWeightedVote;
import com.yahoo.labs.samoa.topology.Stream;
import com.yahoo.labs.samoa.topology.TopologyBuilder;

/**
 * AMRules Regressor
 * is the task for the serialized implementation of AMRules algorithm for regression rule.
 * It is adapted to SAMOA from the implementation of AMRules in MOA.
 * <p>
 * All learning happens in a single {@link AMRulesRegressorProcessor}; the CLI
 * options below are forwarded to that processor's builder in {@link #init}.
 *
 * @author Anh Thu Vu
 *
 */

public class AMRulesRegressor implements RegressionLearner, Configurable {

    /**
     *
     */
    private static final long serialVersionUID = 1L;

    // Options

    // Hoeffding bound confidence for rule-split decisions.
    public FloatOption splitConfidenceOption = new FloatOption(
            "splitConfidence",
            'c',
            "Hoeffding Bound Parameter. The allowable error in split decision, values closer to 0 will take longer to decide.",
            0.0000001, 0.0, 1.0);

    public FloatOption tieThresholdOption = new FloatOption("tieThreshold",
            't', "Hoeffding Bound Parameter. Threshold below which a split will be forced to break ties.",
            0.05, 0.0, 1.0);

    public IntOption gracePeriodOption = new IntOption("gracePeriod",
            'g', "Hoeffding Bound Parameter. The number of instances a leaf should observe between split attempts.",
            200, 1, Integer.MAX_VALUE);

    // NOTE(review): option id is "DoNotDetectChanges" but isSet() is passed
    // straight to Builder.changeDetection(...) below — the semantics look
    // inverted; confirm against AMRulesRegressorProcessor.Builder.
    public FlagOption DriftDetectionOption = new FlagOption("DoNotDetectChanges", 'H',
            "Drift Detection. Page-Hinkley.");

    public FloatOption pageHinckleyAlphaOption = new FloatOption(
            "pageHinckleyAlpha",
            'a',
            "The alpha value to use in the Page Hinckley change detection tests.",
            0.005, 0.0, 1.0);

    public IntOption pageHinckleyThresholdOption = new IntOption(
            "pageHinckleyThreshold",
            'l',
            "The threshold value (Lambda) to be used in the Page Hinckley change detection tests.",
            35, 0, Integer.MAX_VALUE);

    public FlagOption noAnomalyDetectionOption = new FlagOption("noAnomalyDetection", 'A',
            "Disable anomaly Detection.");

    public FloatOption multivariateAnomalyProbabilityThresholdOption = new FloatOption(
            "multivariateAnomalyProbabilityThresholdd",
            'm',
            "Multivariate anomaly threshold value.",
            0.99, 0.0, 1.0);

    public FloatOption univariateAnomalyProbabilityThresholdOption = new FloatOption(
            "univariateAnomalyprobabilityThreshold",
            'u',
            "Univariate anomaly threshold value.",
            0.10, 0.0, 1.0);

    public IntOption anomalyNumInstThresholdOption = new IntOption(
            "anomalyThreshold",
            'n',
            "The threshold value of anomalies to be used in the anomaly detection.",
            30, 0, Integer.MAX_VALUE); // num minimum of instances to detect anomalies. 15.

    public FlagOption unorderedRulesOption = new FlagOption("setUnorderedRulesOn", 'U',
            "unorderedRules.");

    public ClassOption numericObserverOption = new ClassOption("numericObserver",
            'z', "Numeric observer.",
            FIMTDDNumericAttributeClassLimitObserver.class,
            "FIMTDDNumericAttributeClassLimitObserver");

    public MultiChoiceOption predictionFunctionOption = new MultiChoiceOption(
            "predictionFunctionOption", 'P', "The prediction function to use.", new String[]{
                "Adaptative", "Perceptron", "Target Mean"}, new String[]{
                "Adaptative", "Perceptron", "Target Mean"}, 0);

    public FlagOption constantLearningRatioDecayOption = new FlagOption(
            "learningRatio_Decay_set_constant", 'd',
            "Learning Ratio Decay in Perceptron set to be constant. (The next parameter).");

    public FloatOption learningRatioOption = new FloatOption(
            "learningRatio", 's',
            "Constante Learning Ratio to use for training the Perceptrons in the leaves.", 0.025);

    public ClassOption votingTypeOption = new ClassOption("votingType",
            'V', "Voting Type.",
            ErrorWeightedVote.class,
            "InverseErrorWeightedVote");

    // Processor: the single PI that holds the whole rule model.
    private AMRulesRegressorProcessor processor;

    // Stream carrying the processor's predictions.
    private Stream resultStream;

    /**
     * Builds the (single-processor) topology: forwards all CLI option values
     * to the processor builder, registers the processor with the requested
     * parallelism, and wires its result stream.
     */
    @Override
    public void init(TopologyBuilder topologyBuilder, Instances dataset, int parallelism) {
        this.processor = new AMRulesRegressorProcessor.Builder(dataset)
                .threshold(pageHinckleyThresholdOption.getValue())
                .alpha(pageHinckleyAlphaOption.getValue())
                .changeDetection(this.DriftDetectionOption.isSet())
                .predictionFunction(predictionFunctionOption.getChosenIndex())
                .constantLearningRatioDecay(constantLearningRatioDecayOption.isSet())
                .learningRatio(learningRatioOption.getValue())
                .splitConfidence(splitConfidenceOption.getValue())
                .tieThreshold(tieThresholdOption.getValue())
                .gracePeriod(gracePeriodOption.getValue())
                .noAnomalyDetection(noAnomalyDetectionOption.isSet())
                .multivariateAnomalyProbabilityThreshold(multivariateAnomalyProbabilityThresholdOption.getValue())
                .univariateAnomalyProbabilityThreshold(univariateAnomalyProbabilityThresholdOption.getValue())
                .anomalyNumberOfInstancesThreshold(anomalyNumInstThresholdOption.getValue())
                .unorderedRules(unorderedRulesOption.isSet())
                .numericObserver((FIMTDDNumericAttributeClassLimitObserver) numericObserverOption.getValue())
                .voteType((ErrorWeightedVote) votingTypeOption.getValue())
                .build();

        topologyBuilder.addProcessor(processor, parallelism);

        this.resultStream = topologyBuilder.createStream(processor);
        this.processor.setResultStream(resultStream);
    }

    @Override
    public Processor getInputProcessor() {
        return processor;
    }

    @Override
    public Set<Stream> getResultStreams() {
        return ImmutableSet.of(this.resultStream);
    }
}

// ---- File boundary:
// samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/HorizontalAMRulesRegressor.java
// (separate source file in the original change set) ----

package com.yahoo.labs.samoa.learners.classifiers.rules;

/*
 * #%L
 * SAMOA
 * %%
 * Copyright (C) 2013 - 2014 Yahoo! Inc.
 * %%
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * #L%
 */

import com.google.common.collect.ImmutableSet;
import java.util.Set;

import com.github.javacliparser.ClassOption;
import com.github.javacliparser.Configurable;
import com.github.javacliparser.FlagOption;
import com.github.javacliparser.FloatOption;
import com.github.javacliparser.IntOption;
import com.github.javacliparser.MultiChoiceOption;
import com.yahoo.labs.samoa.core.Processor;
import com.yahoo.labs.samoa.instances.Instances;
import com.yahoo.labs.samoa.learners.RegressionLearner;
import com.yahoo.labs.samoa.learners.classifiers.rules.distributed.AMRDefaultRuleProcessor;
import com.yahoo.labs.samoa.learners.classifiers.rules.distributed.AMRLearnerProcessor;
import com.yahoo.labs.samoa.learners.classifiers.rules.distributed.AMRRuleSetProcessor;
import com.yahoo.labs.samoa.moa.classifiers.rules.core.attributeclassobservers.FIMTDDNumericAttributeClassLimitObserver;
import com.yahoo.labs.samoa.topology.Stream;
import com.yahoo.labs.samoa.topology.TopologyBuilder;

/**
 * Horizontal AMRules Regressor
 * is a distributed learner for regression rules learner.
 * It applies both horizontal parallelism (dividing incoming streams)
 * and vertical parallelism on AMRules algorithm.
 * <p>
 * Topology built in {@link #init}: replicated MODEL (rule set) PIs, a single
 * default-rule ROOT PI, and key-partitioned LEARNER (statistics) PIs, wired
 * together via the streams created there.
 *
 * @author Anh Thu Vu
 *
 */
public class HorizontalAMRulesRegressor implements RegressionLearner, Configurable {

    /**
     *
     */
    private static final long serialVersionUID = 2785944439173586051L;

    // Options

    public FloatOption splitConfidenceOption = new FloatOption(
            "splitConfidence",
            'c',
            "Hoeffding Bound Parameter. The allowable error in split decision, values closer to 0 will take longer to decide.",
            0.0000001, 0.0, 1.0);

    public FloatOption tieThresholdOption = new FloatOption("tieThreshold",
            't', "Hoeffding Bound Parameter. Threshold below which a split will be forced to break ties.",
            0.05, 0.0, 1.0);

    public IntOption gracePeriodOption = new IntOption("gracePeriod",
            'g', "Hoeffding Bound Parameter. The number of instances a leaf should observe between split attempts.",
            200, 1, Integer.MAX_VALUE);

    // NOTE(review): same naming/semantics concern as in AMRulesRegressor —
    // flag id "DoNotDetectChanges" is passed to changeDetection(...); confirm.
    public FlagOption DriftDetectionOption = new FlagOption("DoNotDetectChanges", 'H',
            "Drift Detection. Page-Hinkley.");

    public FloatOption pageHinckleyAlphaOption = new FloatOption(
            "pageHinckleyAlpha",
            'a',
            "The alpha value to use in the Page Hinckley change detection tests.",
            0.005, 0.0, 1.0);

    public IntOption pageHinckleyThresholdOption = new IntOption(
            "pageHinckleyThreshold",
            'l',
            "The threshold value (Lambda) to be used in the Page Hinckley change detection tests.",
            35, 0, Integer.MAX_VALUE);

    public FlagOption noAnomalyDetectionOption = new FlagOption("noAnomalyDetection", 'A',
            "Disable anomaly Detection.");

    public FloatOption multivariateAnomalyProbabilityThresholdOption = new FloatOption(
            "multivariateAnomalyProbabilityThresholdd",
            'm',
            "Multivariate anomaly threshold value.",
            0.99, 0.0, 1.0);

    public FloatOption univariateAnomalyProbabilityThresholdOption = new FloatOption(
            "univariateAnomalyprobabilityThreshold",
            'u',
            "Univariate anomaly threshold value.",
            0.10, 0.0, 1.0);

    public IntOption anomalyNumInstThresholdOption = new IntOption(
            "anomalyThreshold",
            'n',
            "The threshold value of anomalies to be used in the anomaly detection.",
            30, 0, Integer.MAX_VALUE); // num minimum of instances to detect anomalies. 15.

    public FlagOption unorderedRulesOption = new FlagOption("setUnorderedRulesOn", 'U',
            "unorderedRules.");

    public ClassOption numericObserverOption = new ClassOption("numericObserver",
            'z', "Numeric observer.",
            FIMTDDNumericAttributeClassLimitObserver.class,
            "FIMTDDNumericAttributeClassLimitObserver");

    public MultiChoiceOption predictionFunctionOption = new MultiChoiceOption(
            "predictionFunctionOption", 'P', "The prediction function to use.", new String[]{
                "Adaptative", "Perceptron", "Target Mean"}, new String[]{
                "Adaptative", "Perceptron", "Target Mean"}, 0);

    public FlagOption constantLearningRatioDecayOption = new FlagOption(
            "learningRatio_Decay_set_constant", 'd',
            "Learning Ratio Decay in Perceptron set to be constant. (The next parameter).");

    public FloatOption learningRatioOption = new FloatOption(
            "learningRatio", 's',
            "Constante Learning Ratio to use for training the Perceptrons in the leaves.", 0.025);

    // Unlike AMRulesRegressor (ClassOption), the vote type here is selected by
    // index via a MultiChoiceOption and passed as an int to voteType(...).
    public MultiChoiceOption votingTypeOption = new MultiChoiceOption(
            "votingType", 'V', "Voting Type.", new String[]{
                "InverseErrorWeightedVote", "UniformWeightedVote"}, new String[]{
                "InverseErrorWeightedVote", "UniformWeightedVote"}, 0);

    public IntOption learnerParallelismOption = new IntOption(
            "leanerParallelism",
            'p',
            "The number of local statistics PI to do distributed computation",
            1, 1, Integer.MAX_VALUE);

    public IntOption ruleSetParallelismOption = new IntOption(
            "modelParallelism",
            'r',
            "The number of replicated model (rule set) PIs",
            1, 1, Integer.MAX_VALUE);

    // Processor: the replicated model (rule set) PI — also the topology input.
    private AMRRuleSetProcessor model;

    // Predictions emitted by the replicated model PIs.
    private Stream modelResultStream;

    // Predictions emitted by the default-rule (root) PI.
    private Stream rootResultStream;

    // private Stream resultStream;

    /**
     * Builds the distributed topology:
     * MODEL PIs (replicated rule sets), a single ROOT PI holding the default
     * rule, and LEARNER PIs holding per-rule statistics, then connects the
     * streams between them.
     */
    @Override
    public void init(TopologyBuilder topologyBuilder, Instances dataset, int parallelism) {

        // Create MODEL PIs
        this.model = new AMRRuleSetProcessor.Builder(dataset)
                .noAnomalyDetection(noAnomalyDetectionOption.isSet())
                .multivariateAnomalyProbabilityThreshold(multivariateAnomalyProbabilityThresholdOption.getValue())
                .univariateAnomalyProbabilityThreshold(univariateAnomalyProbabilityThresholdOption.getValue())
                .anomalyNumberOfInstancesThreshold(anomalyNumInstThresholdOption.getValue())
                .unorderedRules(unorderedRulesOption.isSet())
                .voteType(votingTypeOption.getChosenIndex())
                .build();

        topologyBuilder.addProcessor(model, this.ruleSetParallelismOption.getValue());

        // MODEL PIs streams: instances not covered by any rule go to the root;
        // covered instances' statistics go to the learners.
        Stream forwardToRootStream = topologyBuilder.createStream(this.model);
        Stream forwardToLearnerStream = topologyBuilder.createStream(this.model);
        this.modelResultStream = topologyBuilder.createStream(this.model);

        this.model.setDefaultRuleStream(forwardToRootStream);
        this.model.setStatisticsStream(forwardToLearnerStream);
        this.model.setResultStream(this.modelResultStream);

        // Create DefaultRule PI (single instance: no parallelism argument).
        AMRDefaultRuleProcessor root = new AMRDefaultRuleProcessor.Builder(dataset)
                .threshold(pageHinckleyThresholdOption.getValue())
                .alpha(pageHinckleyAlphaOption.getValue())
                .changeDetection(this.DriftDetectionOption.isSet())
                .predictionFunction(predictionFunctionOption.getChosenIndex())
                .constantLearningRatioDecay(constantLearningRatioDecayOption.isSet())
                .learningRatio(learningRatioOption.getValue())
                .splitConfidence(splitConfidenceOption.getValue())
                .tieThreshold(tieThresholdOption.getValue())
                .gracePeriod(gracePeriodOption.getValue())
                .numericObserver((FIMTDDNumericAttributeClassLimitObserver) numericObserverOption.getValue())
                .build();

        topologyBuilder.addProcessor(root);

        // Default Rule PI streams: newly created rules are broadcast from here.
        Stream newRuleStream = topologyBuilder.createStream(root);
        this.rootResultStream = topologyBuilder.createStream(root);

        root.setRuleStream(newRuleStream);
        root.setResultStream(this.rootResultStream);

        // Create Learner PIs
        AMRLearnerProcessor learner = new AMRLearnerProcessor.Builder(dataset)
                .splitConfidence(splitConfidenceOption.getValue())
                .tieThreshold(tieThresholdOption.getValue())
                .gracePeriod(gracePeriodOption.getValue())
                .noAnomalyDetection(noAnomalyDetectionOption.isSet())
                .multivariateAnomalyProbabilityThreshold(multivariateAnomalyProbabilityThresholdOption.getValue())
                .univariateAnomalyProbabilityThreshold(univariateAnomalyProbabilityThresholdOption.getValue())
                .anomalyNumberOfInstancesThreshold(anomalyNumInstThresholdOption.getValue())
                .build();

        topologyBuilder.addProcessor(learner, this.learnerParallelismOption.getValue());

        Stream predicateStream = topologyBuilder.createStream(learner);
        learner.setOutputStream(predicateStream);

        // Connect streams
        // to MODEL: new rules and learner predicates are broadcast to every replica.
        topologyBuilder.connectInputAllStream(newRuleStream, this.model);
        topologyBuilder.connectInputAllStream(predicateStream, this.model);
        // to ROOT: uncovered instances are load-balanced (shuffle).
        topologyBuilder.connectInputShuffleStream(forwardToRootStream, root);
        // to LEARNER: statistics are partitioned by key; new rules broadcast.
        topologyBuilder.connectInputKeyStream(forwardToLearnerStream, learner);
        topologyBuilder.connectInputAllStream(newRuleStream, learner);
    }

    @Override
    public Processor getInputProcessor() {
        return model;
    }

    @Override
    public Set<Stream> getResultStreams() {
        // Predictions come from both the model replicas and the root (default rule).
        Set<Stream> streams = ImmutableSet.of(this.modelResultStream, this.rootResultStream);
        return streams;
    }

}
package com.yahoo.labs.samoa.learners.classifiers.rules;

/*
 * #%L
 * SAMOA
 * %%
 * Copyright (C) 2013 - 2014 Yahoo! Inc.
 * %%
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * #L%
 */
import java.util.Set;

import com.github.javacliparser.ClassOption;
import com.github.javacliparser.Configurable;
import com.github.javacliparser.FlagOption;
import com.github.javacliparser.FloatOption;
import com.github.javacliparser.IntOption;
import com.github.javacliparser.MultiChoiceOption;
import com.google.common.collect.ImmutableSet;
import com.yahoo.labs.samoa.core.Processor;
import com.yahoo.labs.samoa.instances.Instances;
import com.yahoo.labs.samoa.learners.RegressionLearner;
import com.yahoo.labs.samoa.learners.classifiers.rules.distributed.AMRulesAggregatorProcessor;
import com.yahoo.labs.samoa.learners.classifiers.rules.distributed.AMRulesStatisticsProcessor;
import com.yahoo.labs.samoa.moa.classifiers.rules.core.attributeclassobservers.FIMTDDNumericAttributeClassLimitObserver;
import com.yahoo.labs.samoa.topology.Stream;
import com.yahoo.labs.samoa.topology.TopologyBuilder;

/**
 * Vertical AMRules Regressor is a distributed learner for regression rules.
 * It applies vertical parallelism on the AMRules regressor: a single
 * {@link AMRulesAggregatorProcessor} PI holds the rule set and routes
 * per-rule statistics work to {@code parallelismHint} replicated
 * {@link AMRulesStatisticsProcessor} PIs (see {@link #init}).
 *
 * @author Anh Thu Vu
 *
 */

public class VerticalAMRulesRegressor implements RegressionLearner, Configurable {

    /**
     *
     */
    private static final long serialVersionUID = 2785944439173586051L;

    // ---- Options: Hoeffding-bound driven rule expansion ----

    // Allowable error in a split decision; smaller values delay splits ('c').
    public FloatOption splitConfidenceOption = new FloatOption(
            "splitConfidence",
            'c',
            "Hoeffding Bound Parameter. The allowable error in split decision, values closer to 0 will take longer to decide.",
            0.0000001, 0.0, 1.0);

    // Tie-breaking threshold for near-equal split candidates ('t').
    public FloatOption tieThresholdOption = new FloatOption("tieThreshold",
            't', "Hoeffding Bound Parameter. Threshold below which a split will be forced to break ties.",
            0.05, 0.0, 1.0);

    // Instances a leaf observes between expansion attempts ('g').
    public IntOption gracePeriodOption = new IntOption("gracePeriod",
            'g', "Hoeffding Bound Parameter. The number of instances a leaf should observe between split attempts.",
            200, 1, Integer.MAX_VALUE);

    // ---- Options: Page-Hinckley change detection ----

    // NOTE(review): the option's CLI name is "DoNotDetectChanges" but its isSet()
    // value is passed directly to the aggregator builder's changeDetection(...)
    // below — verify the intended polarity against AMRulesAggregatorProcessor.
    public FlagOption DriftDetectionOption = new FlagOption("DoNotDetectChanges", 'H',
            "Drift Detection. Page-Hinkley.");

    public FloatOption pageHinckleyAlphaOption = new FloatOption(
            "pageHinckleyAlpha",
            'a',
            "The alpha value to use in the Page Hinckley change detection tests.",
            00.005, 0.0, 1.0);

    public IntOption pageHinckleyThresholdOption = new IntOption(
            "pageHinckleyThreshold",
            'l',
            "The threshold value (Lambda) to be used in the Page Hinckley change detection tests.",
            35, 0, Integer.MAX_VALUE);

    // ---- Options: anomaly detection ----

    public FlagOption noAnomalyDetectionOption = new FlagOption("noAnomalyDetection", 'A',
            "Disable anomaly Detection.");

    // NOTE(review): option name string carries a trailing-"dd" typo
    // ("...Thresholdd"); fixing it would change the user-facing CLI name.
    public FloatOption multivariateAnomalyProbabilityThresholdOption = new FloatOption(
            "multivariateAnomalyProbabilityThresholdd",
            'm',
            "Multivariate anomaly threshold value.",
            0.99, 0.0, 1.0);

    public FloatOption univariateAnomalyProbabilityThresholdOption = new FloatOption(
            "univariateAnomalyprobabilityThreshold",
            'u',
            "Univariate anomaly threshold value.",
            0.10, 0.0, 1.0);

    // Minimum number of instances a rule must see before anomaly checks apply ('n').
    public IntOption anomalyNumInstThresholdOption = new IntOption(
            "anomalyThreshold",
            'n',
            "The threshold value of anomalies to be used in the anomaly detection.",
            30, 0, Integer.MAX_VALUE); // num minimum of instances to detect anomalies. 15.

    // ---- Options: rule-set behaviour and prediction ----

    // Unordered rule sets let every covering rule vote; ordered stop at the first.
    public FlagOption unorderedRulesOption = new FlagOption("setUnorderedRulesOn", 'U',
            "unorderedRules.");

    // Numeric attribute observer used by the leaves ('z').
    public ClassOption numericObserverOption = new ClassOption("numericObserver",
            'z', "Numeric observer.",
            FIMTDDNumericAttributeClassLimitObserver.class,
            "FIMTDDNumericAttributeClassLimitObserver");

    // Leaf prediction strategy: Adaptative=0, Perceptron=1, Target Mean=2.
    public MultiChoiceOption predictionFunctionOption = new MultiChoiceOption(
            "predictionFunctionOption", 'P', "The prediction function to use.", new String[]{
            "Adaptative", "Perceptron", "Target Mean"}, new String[]{
            "Adaptative", "Perceptron", "Target Mean"}, 0);

    public FlagOption constantLearningRatioDecayOption = new FlagOption(
            "learningRatio_Decay_set_constant", 'd',
            "Learning Ratio Decay in Perceptron set to be constant. (The next parameter).");

    public FloatOption learningRatioOption = new FloatOption(
            "learningRatio", 's',
            "Constante Learning Ratio to use for training the Perceptrons in the leaves.", 0.025);

    // How votes of covering rules are combined: index 0 or 1 below.
    public MultiChoiceOption votingTypeOption = new MultiChoiceOption(
            "votingType", 'V', "Voting Type.", new String[]{
            "InverseErrorWeightedVote", "UniformWeightedVote"}, new String[]{
            "InverseErrorWeightedVote", "UniformWeightedVote"}, 0);

    // Number of replicated statistics PIs (the "vertical" parallelism degree).
    public IntOption parallelismHintOption = new IntOption(
            "parallelismHint",
            'p',
            "The number of local statistics PI to do distributed computation",
            1, 1, Integer.MAX_VALUE);

    // Entry-point processor of the topology; exposed via getInputProcessor().
    private AMRulesAggregatorProcessor aggregator;

    // Stream carrying predictions out of the topology; exposed via getResultStreams().
    private Stream resultStream;

    /**
     * Builds the two-processor vertical topology:
     * aggregator --statisticsStream(key-partitioned)--> statistics learners
     * --predicateStream(shuffle)--> aggregator, plus a result stream out of
     * the aggregator. Note: the order of addProcessor/createStream calls is
     * kept exactly as written — stream/PI creation order may be significant
     * to the TopologyBuilder.
     */
    @Override
    public void init(TopologyBuilder topologyBuilder, Instances dataset, int parallelism) {

        // Configure the single aggregator PI from the CLI options above.
        this.aggregator = new AMRulesAggregatorProcessor.Builder(dataset)
                .threshold(pageHinckleyThresholdOption.getValue())
                .alpha(pageHinckleyAlphaOption.getValue())
                .changeDetection(this.DriftDetectionOption.isSet())
                .predictionFunction(predictionFunctionOption.getChosenIndex())
                .constantLearningRatioDecay(constantLearningRatioDecayOption.isSet())
                .learningRatio(learningRatioOption.getValue())
                .splitConfidence(splitConfidenceOption.getValue())
                .tieThreshold(tieThresholdOption.getValue())
                .gracePeriod(gracePeriodOption.getValue())
                .noAnomalyDetection(noAnomalyDetectionOption.isSet())
                .multivariateAnomalyProbabilityThreshold(multivariateAnomalyProbabilityThresholdOption.getValue())
                .univariateAnomalyProbabilityThreshold(univariateAnomalyProbabilityThresholdOption.getValue())
                .anomalyNumberOfInstancesThreshold(anomalyNumInstThresholdOption.getValue())
                .unorderedRules(unorderedRulesOption.isSet())
                .numericObserver((FIMTDDNumericAttributeClassLimitObserver) numericObserverOption.getValue())
                .voteType(votingTypeOption.getChosenIndex())
                .build();

        topologyBuilder.addProcessor(aggregator);

        // Aggregator output streams: one towards the statistics PIs, one with results.
        Stream statisticsStream = topologyBuilder.createStream(aggregator);
        this.resultStream = topologyBuilder.createStream(aggregator);

        this.aggregator.setResultStream(resultStream);
        this.aggregator.setStatisticsStream(statisticsStream);

        // Replicated statistics PIs performing the distributed split computations.
        AMRulesStatisticsProcessor learner = new AMRulesStatisticsProcessor.Builder(dataset)
                .splitConfidence(splitConfidenceOption.getValue())
                .tieThreshold(tieThresholdOption.getValue())
                .gracePeriod(gracePeriodOption.getValue())
                .build();

        topologyBuilder.addProcessor(learner, this.parallelismHintOption.getValue());

        // Key-partitioned so statistics for the same rule land on the same PI.
        topologyBuilder.connectInputKeyStream(statisticsStream, learner);

        // Split decisions flow back to the aggregator (shuffle: single aggregator).
        Stream predicateStream = topologyBuilder.createStream(learner);
        learner.setOutputStream(predicateStream);

        topologyBuilder.connectInputShuffleStream(predicateStream, aggregator);
    }

    /** The aggregator PI is the topology's entry point for instances. */
    @Override
    public Processor getInputProcessor() {
        return aggregator;
    }

    /** Only the aggregator's result stream leaves this learner. */
    @Override
    public Set<Stream> getResultStreams() {
        return ImmutableSet.of(this.resultStream);
    }
}
package com.yahoo.labs.samoa.learners.classifiers.rules.centralized;

/*
 * #%L
 * SAMOA
 * %%
 * Copyright (C) 2013 - 2014 Yahoo! Inc.
 * %%
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * #L%
 */

import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;

import com.yahoo.labs.samoa.core.ContentEvent;
import com.yahoo.labs.samoa.core.Processor;
import com.yahoo.labs.samoa.instances.Instance;
import com.yahoo.labs.samoa.instances.Instances;
import com.yahoo.labs.samoa.learners.InstanceContentEvent;
import com.yahoo.labs.samoa.learners.ResultContentEvent;
import com.yahoo.labs.samoa.learners.classifiers.rules.common.ActiveRule;
import com.yahoo.labs.samoa.learners.classifiers.rules.common.Perceptron;
import com.yahoo.labs.samoa.learners.classifiers.rules.common.RuleActiveRegressionNode;
import com.yahoo.labs.samoa.moa.classifiers.rules.core.attributeclassobservers.FIMTDDNumericAttributeClassLimitObserver;
import com.yahoo.labs.samoa.moa.classifiers.rules.core.voting.ErrorWeightedVote;
import com.yahoo.labs.samoa.topology.Stream;

/**
 * AMRules Regressor Processor is the main (and only) processor for the
 * AMRulesRegressor task. It is adapted from the AMRules implementation in MOA.
 * It keeps an ordered/unordered set of {@link ActiveRule}s plus one default
 * rule, predicts on testing instances via an error-weighted vote, and trains
 * by updating/expanding the covering rules (or the default rule).
 *
 * @author Anh Thu Vu
 *
 */
public class AMRulesRegressorProcessor implements Processor {
    /**
     *
     */
    private static final long serialVersionUID = 1L;

    // Id assigned by the platform in onCreate(); stamped onto result events.
    private int processorId;

    // Rules & default rule.
    protected List<ActiveRule> ruleSet;
    // Rule covering instances matched by no other rule; seeds new rules on split.
    protected ActiveRule defaultRule;
    // Monotonically increasing id source for newly created rules.
    protected int ruleNumberID;
    protected double[] statistics;

    // SAMOA stream carrying ResultContentEvents (predictions).
    private Stream resultStream;

    // ---- Options (copied verbatim from the Builder) ----
    protected int pageHinckleyThreshold;
    protected double pageHinckleyAlpha;
    protected boolean driftDetection;
    protected int predictionFunction; // Adaptive=0 Perceptron=1 TargetMean=2
    protected boolean constantLearningRatioDecay;
    protected double learningRatio;

    protected double splitConfidence;
    protected double tieThreshold;
    protected int gracePeriod;

    protected boolean noAnomalyDetection;
    protected double multivariateAnomalyProbabilityThreshold;
    protected double univariateAnomalyprobabilityThreshold;
    protected int anomalyNumInstThreshold;

    // true: all covering rules vote; false: only the first covering rule.
    protected boolean unorderedRules;

    protected FIMTDDNumericAttributeClassLimitObserver numericObserver;

    // Prototype vote combiner; copied per prediction in newErrorWeightedVote().
    protected ErrorWeightedVote voteType;

    /*
     * Constructor: copies every configuration field from the Builder.
     */
    public AMRulesRegressorProcessor(Builder builder) {
        this.pageHinckleyThreshold = builder.pageHinckleyThreshold;
        this.pageHinckleyAlpha = builder.pageHinckleyAlpha;
        this.driftDetection = builder.driftDetection;
        this.predictionFunction = builder.predictionFunction;
        this.constantLearningRatioDecay = builder.constantLearningRatioDecay;
        this.learningRatio = builder.learningRatio;
        this.splitConfidence = builder.splitConfidence;
        this.tieThreshold = builder.tieThreshold;
        this.gracePeriod = builder.gracePeriod;

        this.noAnomalyDetection = builder.noAnomalyDetection;
        this.multivariateAnomalyProbabilityThreshold = builder.multivariateAnomalyProbabilityThreshold;
        this.univariateAnomalyprobabilityThreshold = builder.univariateAnomalyprobabilityThreshold;
        this.anomalyNumInstThreshold = builder.anomalyNumInstThreshold;
        this.unorderedRules = builder.unorderedRules;

        this.numericObserver = builder.numericObserver;
        this.voteType = builder.voteType;
    }

    /**
     * Dispatches an incoming instance: predict when flagged as testing,
     * train when flagged as training (an instance may be both).
     * Always reports the event as consumed.
     */
    @Override
    public boolean process(ContentEvent event) {
        InstanceContentEvent instanceEvent = (InstanceContentEvent) event;

        // predict
        if (instanceEvent.isTesting()) {
            this.predictOnInstance(instanceEvent);
        }

        // train
        if (instanceEvent.isTraining()) {
            this.trainOnInstance(instanceEvent);
        }

        return true;
    }

    /*
     * Prediction: compute the vote and emit it on the result stream.
     */
    private void predictOnInstance(InstanceContentEvent instanceEvent) {
        double[] prediction = getVotesForInstance(instanceEvent.getInstance());
        ResultContentEvent rce = newResultContentEvent(prediction, instanceEvent);
        resultStream.put(rce);
    }

    /**
     * Helper method to generate new ResultContentEvent based on an instance and
     * its prediction result.
     * @param prediction The predicted class label from the decision tree model.
     * @param inEvent The associated instance content event
     * @return ResultContentEvent to be sent into Evaluator PI or other destination PI.
     */
    private ResultContentEvent newResultContentEvent(double[] prediction, InstanceContentEvent inEvent) {
        ResultContentEvent rce = new ResultContentEvent(inEvent.getInstanceIndex(), inEvent.getInstance(), inEvent.getClassId(), prediction, inEvent.isLastEvent());
        rce.setClassifierIndex(this.processorId);
        rce.setEvaluationIndex(inEvent.getEvaluationIndex());
        return rce;
    }

    /**
     * getVotesForInstance extension of the instance method getVotesForInstance
     * in moa.classifier.java; returns the prediction for the instance.
     * Called in EvaluateModelRegression.
     *
     * Each covering rule contributes (prediction, currentError) to an
     * ErrorWeightedVote; with ordered rules only the first covering rule
     * votes. If no rule covers the instance, the default rule votes alone.
     */
    private double[] getVotesForInstance(Instance instance) {
        ErrorWeightedVote errorWeightedVote = newErrorWeightedVote();
        int numberOfRulesCovering = 0;

        for (ActiveRule rule : ruleSet) {
            if (rule.isCovering(instance) == true) {
                numberOfRulesCovering++;
                double[] vote = rule.getPrediction(instance);
                double error = rule.getCurrentError();
                errorWeightedVote.addVote(vote, error);
                if (!this.unorderedRules) { // Ordered Rules Option.
                    break; // Only one rule cover the instance.
                }
            }
        }

        if (numberOfRulesCovering == 0) {
            // No rule covers: fall back to the default rule's prediction.
            double[] vote = defaultRule.getPrediction(instance);
            double error = defaultRule.getCurrentError();
            errorWeightedVote.addVote(vote, error);
        }
        double[] weightedVote = errorWeightedVote.computeWeightedVote();

        return weightedVote;
    }

    /** Fresh copy of the configured vote combiner (one per prediction). */
    public ErrorWeightedVote newErrorWeightedVote() {
        return voteType.getACopy();
    }

    /*
     * Training entry point: unwraps the instance from the event.
     */
    private void trainOnInstance(InstanceContentEvent instanceEvent) {
        this.trainOnInstanceImpl(instanceEvent.getInstance());
    }

    /**
     * Core AMRules training step for one instance.
     *
     * AMRules Algorithm:
     * For each rule in the rule set
     *   If rule covers the instance
     *     If the instance is not an anomaly
     *       Update change detection tests (compute prediction error, PH test)
     *       If change is detected then remove rule
     *       Else update sufficient statistics of rule;
     *            if examples seen is a multiple of the grace period, try to expand
     *     If ordered set then break
     * If none of the rules covers the instance
     *   Update sufficient statistics of default rule
     *   If examples in default rule is a multiple of the grace period,
     *   expand default rule, add it to the rule set, and reset the default rule
     */
    public void trainOnInstanceImpl(Instance instance) {
        boolean rulesCoveringInstance = false;
        // Explicit iterator so rules failing the PH test can be removed in-place.
        Iterator<ActiveRule> ruleIterator = this.ruleSet.iterator();
        while (ruleIterator.hasNext()) {
            ActiveRule rule = ruleIterator.next();
            if (rule.isCovering(instance) == true) {
                rulesCoveringInstance = true;
                if (isAnomaly(instance, rule) == false) {
                    //Update Change Detection Tests
                    double error = rule.computeError(instance); //Use adaptive mode error
                    boolean changeDetected = ((RuleActiveRegressionNode) rule.getLearningNode()).updateChangeDetection(error);
                    if (changeDetected == true) {
                        // Concept drift detected for this rule: drop it.
                        ruleIterator.remove();
                    } else {
                        rule.updateStatistics(instance);
                        // Attempt expansion once per grace period.
                        if (rule.getInstancesSeen() % this.gracePeriod == 0.0) {
                            if (rule.tryToExpand(this.splitConfidence, this.tieThreshold)) {
                                rule.split();
                            }
                        }
                    }
                    // NOTE(review): with ordered rules the break sits inside the
                    // non-anomaly branch, so an anomalous instance is still offered
                    // to later covering rules — confirm this matches MOA's AMRules.
                    if (!this.unorderedRules)
                        break;
                }
            }
        }

        if (rulesCoveringInstance == false) {
            // No rule covered the instance: train the default rule instead.
            defaultRule.updateStatistics(instance);
            if (defaultRule.getInstancesSeen() % this.gracePeriod == 0.0) {
                if (defaultRule.tryToExpand(this.splitConfidence, this.tieThreshold) == true) {
                    // The "other branch" statistics seed the replacement default rule.
                    ActiveRule newDefaultRule = newRule(defaultRule.getRuleNumberID(), (RuleActiveRegressionNode) defaultRule.getLearningNode(),
                            ((RuleActiveRegressionNode) defaultRule.getLearningNode()).getStatisticsOtherBranchSplit()); //other branch
                    // The old default rule specialises via split() and is promoted
                    // into the rule set under a fresh id; order of these calls matters.
                    defaultRule.split();
                    defaultRule.setRuleNumberID(++ruleNumberID);
                    this.ruleSet.add(this.defaultRule);

                    defaultRule = newDefaultRule;

                }
            }
        }
    }

    /**
     * Method to verify if the instance is an anomaly.
     * Only checked when anomaly detection is enabled and the rule has seen
     * at least anomalyNumInstThreshold instances.
     * @param instance
     * @param rule
     * @return true when the rule flags the instance as anomalous
     */
    private boolean isAnomaly(Instance instance, ActiveRule rule) {
        //AMRules is equipped with anomaly detection. If on, compute the anomaly value.
        boolean isAnomaly = false;
        if (this.noAnomalyDetection == false) {
            if (rule.getInstancesSeen() >= this.anomalyNumInstThreshold) {
                isAnomaly = rule.isAnomaly(instance,
                        this.univariateAnomalyprobabilityThreshold,
                        this.multivariateAnomalyProbabilityThreshold,
                        this.anomalyNumInstThreshold);
            }
        }
        return isAnomaly;
    }

    /**
     * Creates a new rule seeded from an existing learning node: copies the
     * perceptron (when present) and initialises the target mean either from
     * the node's own statistics (statistics == null) or from the supplied
     * statistics array ({count, sum, ...} layout implied by the mean formula).
     */
    // TODO check this after finish rule, LN
    private ActiveRule newRule(int ID, RuleActiveRegressionNode node, double[] statistics) {
        ActiveRule r = newRule(ID);

        if (node != null)
        {
            if (node.getPerceptron() != null)
            {
                // Deep-copy the perceptron so the new rule trains independently.
                r.getLearningNode().setPerceptron(new Perceptron(node.getPerceptron()));
                r.getLearningNode().getPerceptron().setLearningRatio(this.learningRatio);
            }
            if (statistics == null)
            {
                double mean;
                if (node.getNodeStatistics().getValue(0) > 0) {
                    mean = node.getNodeStatistics().getValue(1) / node.getNodeStatistics().getValue(0);
                    r.getLearningNode().getTargetMean().reset(mean, 1);
                }
            }
        }
        if (statistics != null && ((RuleActiveRegressionNode) r.getLearningNode()).getTargetMean() != null)
        {
            double mean;
            if (statistics[0] > 0) {
                mean = statistics[1] / statistics[0];
                ((RuleActiveRegressionNode) r.getLearningNode()).getTargetMean().reset(mean, (long) statistics[0]);
            }
        }
        return r;
    }

    /** Creates a blank rule carrying this processor's configuration. */
    private ActiveRule newRule(int ID) {
        ActiveRule r = new ActiveRule.Builder().
                threshold(this.pageHinckleyThreshold).
                alpha(this.pageHinckleyAlpha).
                changeDetection(this.driftDetection).
                predictionFunction(this.predictionFunction).
                statistics(new double[3]).
                learningRatio(this.learningRatio).
                numericObserver(numericObserver).
                id(ID).build();
        return r;
    }

    /*
     * Init processor: reset model state. The default rule takes id 1.
     */
    @Override
    public void onCreate(int id) {
        this.processorId = id;
        this.statistics = new double[]{0.0, 0, 0};
        this.ruleNumberID = 0;
        this.defaultRule = newRule(++this.ruleNumberID);

        this.ruleSet = new LinkedList<ActiveRule>();
    }

    /*
     * Clone processor: configuration is copied via the Builder; the learned
     * rule set is NOT copied (the clone starts empty until onCreate runs).
     */
    @Override
    public Processor newProcessor(Processor p) {
        AMRulesRegressorProcessor oldProcessor = (AMRulesRegressorProcessor) p;
        Builder builder = new Builder(oldProcessor);
        AMRulesRegressorProcessor newProcessor = builder.build();
        newProcessor.resultStream = oldProcessor.resultStream;
        return newProcessor;
    }

    /*
     * Output stream accessors.
     */
    public void setResultStream(Stream resultStream) {
        this.resultStream = resultStream;
    }

    public Stream getResultStream() {
        return this.resultStream;
    }

    /*
     * Others
     */
    public boolean isRandomizable() {
        return true;
    }

    /**
     * Builder for AMRulesRegressorProcessor; mirrors the processor's
     * configuration fields one-to-one. The copy constructor supports
     * newProcessor(Processor).
     */
    public static class Builder {
        private int pageHinckleyThreshold;
        private double pageHinckleyAlpha;
        private boolean driftDetection;
        private int predictionFunction; // Adaptive=0 Perceptron=1 TargetMean=2
        private boolean constantLearningRatioDecay;
        private double learningRatio;
        private double splitConfidence;
        private double tieThreshold;
        private int gracePeriod;

        private boolean noAnomalyDetection;
        private double multivariateAnomalyProbabilityThreshold;
        private double univariateAnomalyprobabilityThreshold;
        private int anomalyNumInstThreshold;

        private boolean unorderedRules;

        private FIMTDDNumericAttributeClassLimitObserver numericObserver;
        private ErrorWeightedVote voteType;

        // NOTE(review): stored but not read anywhere visible in this class.
        private Instances dataset;

        public Builder(Instances dataset) {
            this.dataset = dataset;
        }

        /** Copy-constructor used when cloning a processor. */
        public Builder(AMRulesRegressorProcessor processor) {
            this.pageHinckleyThreshold = processor.pageHinckleyThreshold;
            this.pageHinckleyAlpha = processor.pageHinckleyAlpha;
            this.driftDetection = processor.driftDetection;
            this.predictionFunction = processor.predictionFunction;
            this.constantLearningRatioDecay = processor.constantLearningRatioDecay;
            this.learningRatio = processor.learningRatio;
            this.splitConfidence = processor.splitConfidence;
            this.tieThreshold = processor.tieThreshold;
            this.gracePeriod = processor.gracePeriod;

            this.noAnomalyDetection = processor.noAnomalyDetection;
            this.multivariateAnomalyProbabilityThreshold = processor.multivariateAnomalyProbabilityThreshold;
            this.univariateAnomalyprobabilityThreshold = processor.univariateAnomalyprobabilityThreshold;
            this.anomalyNumInstThreshold = processor.anomalyNumInstThreshold;
            this.unorderedRules = processor.unorderedRules;

            this.numericObserver = processor.numericObserver;
            this.voteType = processor.voteType;
        }

        public Builder threshold(int threshold) {
            this.pageHinckleyThreshold = threshold;
            return this;
        }

        public Builder alpha(double alpha) {
            this.pageHinckleyAlpha = alpha;
            return this;
        }

        public Builder changeDetection(boolean changeDetection) {
            this.driftDetection = changeDetection;
            return this;
        }

        public Builder predictionFunction(int predictionFunction) {
            this.predictionFunction = predictionFunction;
            return this;
        }

        public Builder constantLearningRatioDecay(boolean constantDecay) {
            this.constantLearningRatioDecay = constantDecay;
            return this;
        }

        public Builder learningRatio(double learningRatio) {
            this.learningRatio = learningRatio;
            return this;
        }

        public Builder splitConfidence(double splitConfidence) {
            this.splitConfidence = splitConfidence;
            return this;
        }

        public Builder tieThreshold(double tieThreshold) {
            this.tieThreshold = tieThreshold;
            return this;
        }

        public Builder gracePeriod(int gracePeriod) {
            this.gracePeriod = gracePeriod;
            return this;
        }

        public Builder noAnomalyDetection(boolean noAnomalyDetection) {
            this.noAnomalyDetection = noAnomalyDetection;
            return this;
        }

        public Builder multivariateAnomalyProbabilityThreshold(double mAnomalyThreshold) {
            this.multivariateAnomalyProbabilityThreshold = mAnomalyThreshold;
            return this;
        }

        public Builder univariateAnomalyProbabilityThreshold(double uAnomalyThreshold) {
            this.univariateAnomalyprobabilityThreshold = uAnomalyThreshold;
            return this;
        }

        public Builder anomalyNumberOfInstancesThreshold(int anomalyNumInstThreshold) {
            this.anomalyNumInstThreshold = anomalyNumInstThreshold;
            return this;
        }

        public Builder unorderedRules(boolean unorderedRules) {
            this.unorderedRules = unorderedRules;
            return this;
        }

        public Builder numericObserver(FIMTDDNumericAttributeClassLimitObserver numericObserver) {
            this.numericObserver = numericObserver;
            return this;
        }

        public Builder voteType(ErrorWeightedVote voteType) {
            this.voteType = voteType;
            return this;
        }

        public AMRulesRegressorProcessor build() {
            return new AMRulesRegressorProcessor(this);
        }
    }
}
package com.yahoo.labs.samoa.learners.classifiers.rules.common;

/*
 * #%L
 * SAMOA
 * %%
 * Copyright (C) 2013 - 2014 Yahoo! Inc.
 * %%
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * #L%
 */

import java.io.Serializable;

import com.yahoo.labs.samoa.moa.classifiers.core.conditionaltests.InstanceConditionalTest;
import com.yahoo.labs.samoa.moa.classifiers.core.conditionaltests.NumericAttributeBinaryTest;
import com.yahoo.labs.samoa.moa.classifiers.rules.core.attributeclassobservers.FIMTDDNumericAttributeClassLimitObserver;
import com.yahoo.labs.samoa.moa.classifiers.rules.core.conditionaltests.NumericAttributeBinaryRulePredicate;

/**
 * ActiveRule is a LearningRule that actively updates its LearningNode
 * with incoming instances. Expansion works in two phases: tryToExpand(...)
 * asks the learning node whether a split is justified, and split() then
 * appends the chosen predicate to the rule and installs a fresh learning node.
 *
 * @author Anh Thu Vu
 *
 */

public class ActiveRule extends LearningRule {

    private static final long serialVersionUID = 1L;

    // NOTE(review): never assigned in this class (the assignment in split()
    // is commented out), so statisticsOtherBranchSplit() always returns null.
    private double[] statisticsOtherBranchSplit;

    // Retained so split() can derive the next learning node's configuration.
    private Builder builder;

    // Statistics/expansion state for this rule; replaced on each split.
    private RuleActiveRegressionNode learningNode;

    // The split node created by the most recent successful split().
    private RuleSplitNode lastUpdatedRuleSplitNode;

    /*
     * Constructors: no-arg form leaves the rule unconfigured; Builder form
     * creates the initial learning node and adopts the builder's id.
     */
    public ActiveRule() {
        super();
        this.builder = null;
        this.learningNode = null;
        this.ruleNumberID = 0;
    }
    public ActiveRule(Builder builder) {
        super();
        this.setBuilder(builder);
        this.learningNode = newRuleActiveLearningNode(builder);
        //JD - use builder ID to set ruleNumberID
        this.ruleNumberID = builder.id;
    }

    // Factory for the rule's learning node; also reused by split().
    private RuleActiveRegressionNode newRuleActiveLearningNode(Builder builder) {
        return new RuleActiveRegressionNode(builder);
    }

    /*
     * Setters & getters
     */
    public Builder getBuilder() {
        return builder;
    }

    public void setBuilder(Builder builder) {
        this.builder = builder;
    }

    @Override
    public RuleRegressionNode getLearningNode() {
        return this.learningNode;
    }

    /** Caller must supply a RuleActiveRegressionNode; the cast is unchecked here. */
    @Override
    public void setLearningNode(RuleRegressionNode learningNode) {
        this.learningNode = (RuleActiveRegressionNode) learningNode;
    }

    public double[] statisticsOtherBranchSplit() {
        return this.statisticsOtherBranchSplit;
    }

    public RuleSplitNode getLastUpdatedRuleSplitNode() {
        return this.lastUpdatedRuleSplitNode;
    }

    /**
     * Builder collecting the configuration for an ActiveRule and its
     * RuleActiveRegressionNode. Serializable because rules travel between PIs.
     */
    public static class Builder implements Serializable {

        private static final long serialVersionUID = 1712887264918475622L;
        protected boolean changeDetection;
        // NOTE(review): no fluent setter below assigns usePerceptron or
        // lastTargetMean; presumably consumed by RuleActiveRegressionNode.
        protected boolean usePerceptron;
        protected double threshold;
        protected double alpha;
        protected int predictionFunction;
        protected boolean constantLearningRatioDecay;
        protected double learningRatio;

        protected double[] statistics;

        protected FIMTDDNumericAttributeClassLimitObserver numericObserver;

        protected double lastTargetMean;

        public int id;

        public Builder() {
        }

        public Builder changeDetection(boolean changeDetection) {
            this.changeDetection = changeDetection;
            return this;
        }

        public Builder threshold(double threshold) {
            this.threshold = threshold;
            return this;
        }

        public Builder alpha(double alpha) {
            this.alpha = alpha;
            return this;
        }

        public Builder predictionFunction(int predictionFunction) {
            this.predictionFunction = predictionFunction;
            return this;
        }

        public Builder statistics(double[] statistics) {
            this.statistics = statistics;
            return this;
        }

        public Builder constantLearningRatioDecay(boolean constantLearningRatioDecay) {
            this.constantLearningRatioDecay = constantLearningRatioDecay;
            return this;
        }

        public Builder learningRatio(double learningRatio) {
            this.learningRatio = learningRatio;
            return this;
        }

        public Builder numericObserver(FIMTDDNumericAttributeClassLimitObserver numericObserver) {
            this.numericObserver = numericObserver;
            return this;
        }

        public Builder id(int id) {
            this.id = id;
            return this;
        }
        public ActiveRule build() {
            return new ActiveRule(this);
        }

    }

    /**
     * Try to Expand method: delegates the Hoeffding-bound split decision to
     * the learning node.
     * @param splitConfidence allowable error in the split decision
     * @param tieThreshold tie-breaking threshold
     * @return true when the node decides a split is warranted
     */
    public boolean tryToExpand(double splitConfidence, double tieThreshold) {

        boolean shouldSplit = this.learningNode.tryToExpand(splitConfidence, tieThreshold);
        return shouldSplit;

    }

    //JD: Only call after tryToExpand returning true
    // Appends the best split's predicate as a RuleSplitNode and, if the node
    // list accepted it, replaces the learning node with a fresh one seeded from
    // the split's statistics. Only numeric binary tests are supported.
    public void split()
    {
        //this.statisticsOtherBranchSplit = this.learningNode.getStatisticsOtherBranchSplit();
        //create a split node,
        int splitIndex = this.learningNode.getSplitIndex();
        InstanceConditionalTest st = this.learningNode.getBestSuggestion().splitTest;
        if (st instanceof NumericAttributeBinaryTest) {
            NumericAttributeBinaryTest splitTest = (NumericAttributeBinaryTest) st;
            // splitIndex + 1 encodes the chosen branch/operator for the predicate.
            NumericAttributeBinaryRulePredicate predicate = new NumericAttributeBinaryRulePredicate(
                    splitTest.getAttsTestDependsOn()[0], splitTest.getSplitValue(),
                    splitIndex + 1);
            lastUpdatedRuleSplitNode = new RuleSplitNode(predicate, this.learningNode.getStatisticsBranchSplit());
            if (this.nodeListAdd(lastUpdatedRuleSplitNode)) {
                // create a new learning node carrying over the relevant statistics
                RuleActiveRegressionNode newLearningNode = newRuleActiveLearningNode(this.getBuilder().statistics(this.learningNode.getStatisticsNewRuleActiveLearningNode()));
                newLearningNode.initialize(this.learningNode);
                this.learningNode = newLearningNode;
            }
        }
        else
            throw new UnsupportedOperationException("AMRules (currently) only supports numerical attributes.");
    }

    /**
     * MOA GUI output (intentionally empty: no description is rendered).
     */
    @Override
    public void getDescription(StringBuilder sb, int indent) {
    }
}
http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/787864b6/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/common/LearningRule.java ---------------------------------------------------------------------- diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/common/LearningRule.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/common/LearningRule.java new file mode 100644 index 0000000..4c05632 --- /dev/null +++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/common/LearningRule.java @@ -0,0 +1,122 @@ +package com.yahoo.labs.samoa.learners.classifiers.rules.common; + +/* + * #%L + * SAMOA + * %% + * Copyright (C) 2013 - 2014 Yahoo! Inc. + * %% + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * #L% + */ + +import com.yahoo.labs.samoa.instances.Instance; +import com.yahoo.labs.samoa.moa.core.DoubleVector; +import com.yahoo.labs.samoa.moa.core.StringUtils; + +/** + * Rule with LearningNode (statistical data). + * + * @author Anh Thu Vu + * + */ +public abstract class LearningRule extends Rule { + + /** + * + */ + private static final long serialVersionUID = 1L; + + /* + * Constructor + */ + public LearningRule() { + super(); + } + + /* + * LearningNode + */ + public abstract RuleRegressionNode getLearningNode(); + + public abstract void setLearningNode(RuleRegressionNode learningNode); + + /* + * No. 
of instances seen + */ + public long getInstancesSeen() { + return this.getLearningNode().getInstancesSeen(); + } + + /* + * Error and change detection + */ + public double computeError(Instance instance) { + return this.getLearningNode().computeError(instance); + } + + + /* + * Prediction + */ + public double[] getPrediction(Instance instance, int mode) { + return this.getLearningNode().getPrediction(instance, mode); + } + + public double[] getPrediction(Instance instance) { + return this.getLearningNode().getPrediction(instance); + } + + public double getCurrentError() { + return this.getLearningNode().getCurrentError(); + } + + /* + * Anomaly detection + */ + public boolean isAnomaly(Instance instance, + double uniVariateAnomalyProbabilityThreshold, + double multiVariateAnomalyProbabilityThreshold, + int numberOfInstanceesForAnomaly) { + return this.getLearningNode().isAnomaly(instance, uniVariateAnomalyProbabilityThreshold, + multiVariateAnomalyProbabilityThreshold, + numberOfInstanceesForAnomaly); + } + + /* + * Update + */ + public void updateStatistics(Instance instance) { + this.getLearningNode().updateStatistics(instance); + } + + public String printRule() { + StringBuilder out = new StringBuilder(); + int indent = 1; + StringUtils.appendIndented(out, indent, "Rule Nr." 
+ this.ruleNumberID + " Instances seen:" + this.getLearningNode().getInstancesSeen() + "\n"); // AC + for (RuleSplitNode node : nodeList) { + StringUtils.appendIndented(out, indent, node.getSplitTest().toString()); + StringUtils.appendIndented(out, indent, " "); + StringUtils.appendIndented(out, indent, node.toString()); + } + DoubleVector pred = new DoubleVector(this.getLearningNode().getSimplePrediction()); + StringUtils.appendIndented(out, 0, " --> y: " + pred.toString()); + StringUtils.appendNewline(out); + + if(getLearningNode().perceptron!=null){ + ((RuleActiveRegressionNode)this.getLearningNode()).perceptron.getModelDescription(out,0); + StringUtils.appendNewline(out); + } + return(out.toString()); + } +} http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/787864b6/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/common/NonLearningRule.java ---------------------------------------------------------------------- diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/common/NonLearningRule.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/common/NonLearningRule.java new file mode 100644 index 0000000..df5b9f9 --- /dev/null +++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/common/NonLearningRule.java @@ -0,0 +1,51 @@ +package com.yahoo.labs.samoa.learners.classifiers.rules.common; + +/* + * #%L + * SAMOA + * %% + * Copyright (C) 2013 - 2014 Yahoo! Inc. + * %% + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + * #L% + */ + +/** + * The most basic rule: inherit from Rule the ID and list of features. + * + * @author Anh Thu Vu + * + */ +/* + * This branch (Non-learning rule) was created for an old implementation. + * Probably should remove None-Learning and Learning Rule classes, + * merge Rule with LearningRule. + */ +public class NonLearningRule extends Rule { + + /** + * + */ + private static final long serialVersionUID = -1210907339230307784L; + + public NonLearningRule(ActiveRule rule) { + this.nodeList = rule.nodeList; + this.ruleNumberID = rule.ruleNumberID; + } + + @Override + public void getDescription(StringBuilder sb, int indent) { + // do nothing + } + +} http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/787864b6/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/common/PassiveRule.java ---------------------------------------------------------------------- diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/common/PassiveRule.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/common/PassiveRule.java new file mode 100644 index 0000000..8281d45 --- /dev/null +++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/common/PassiveRule.java @@ -0,0 +1,71 @@ +package com.yahoo.labs.samoa.learners.classifiers.rules.common; + +/* + * #%L + * SAMOA + * %% + * Copyright (C) 2013 - 2014 Yahoo! Inc. + * %% + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + * #L% + */ + +import java.util.LinkedList; + +/** + * PassiveRule is a LearningRule that update its LearningNode + * with the received new LearningNode. + * + * @author Anh Thu Vu + * + */ +public class PassiveRule extends LearningRule { + + /** + * + */ + private static final long serialVersionUID = -5551571895910530275L; + + private RulePassiveRegressionNode learningNode; + + /* + * Constructor to turn an ActiveRule into a PassiveRule + */ + public PassiveRule(ActiveRule rule) { + this.nodeList = new LinkedList<>(); + for (RuleSplitNode node:rule.nodeList) { + this.nodeList.add(node.getACopy()); + } + + this.learningNode = new RulePassiveRegressionNode(rule.getLearningNode()); + this.ruleNumberID = rule.ruleNumberID; + } + + @Override + public RuleRegressionNode getLearningNode() { + return this.learningNode; + } + + @Override + public void setLearningNode(RuleRegressionNode learningNode) { + this.learningNode = (RulePassiveRegressionNode) learningNode; + } + + /* + * MOA GUI + */ + @Override + public void getDescription(StringBuilder sb, int indent) { + // TODO Auto-generated method stub + } +} http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/787864b6/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/common/Perceptron.java ---------------------------------------------------------------------- diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/common/Perceptron.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/common/Perceptron.java new file mode 100644 index 0000000..53583ed --- /dev/null +++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/common/Perceptron.java @@ -0,0 +1,487 @@ +package com.yahoo.labs.samoa.learners.classifiers.rules.common; + +/* + * #%L + * SAMOA + * %% + * Copyright (C) 2013 - 2014 Yahoo! Inc. 
package com.yahoo.labs.samoa.learners.classifiers.rules.common;

/*
 * #%L
 * SAMOA
 * %%
 * Copyright (C) 2013 - 2014 Yahoo! Inc.
 * %%
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * #L%
 */

import java.io.Serializable;

import com.esotericsoftware.kryo.Kryo;
import com.esotericsoftware.kryo.Serializer;
import com.esotericsoftware.kryo.io.Input;
import com.esotericsoftware.kryo.io.Output;
import com.yahoo.labs.samoa.instances.Instance;
import com.yahoo.labs.samoa.moa.classifiers.AbstractClassifier;
import com.yahoo.labs.samoa.moa.classifiers.Regressor;
import com.yahoo.labs.samoa.moa.core.DoubleVector;
import com.yahoo.labs.samoa.moa.core.Measurement;

/**
 * Prediction scheme using a Perceptron: predictions are computed according to
 * a linear function of the (normalized) attributes. Attribute and target
 * values are standardized on-line using running sums and sums of squares.
 *
 * @author Anh Thu Vu
 */
public class Perceptron extends AbstractClassifier implements Regressor {

  // Threshold below which a standard deviation is treated as zero when
  // normalizing attribute and target values.
  private final double SD_THRESHOLD = 0.0000001;

  private static final long serialVersionUID = 1L;

  protected boolean constantLearningRatioDecay;
  protected double originalLearningRatio;

  // Fading-factor-weighted count and sum of absolute errors.
  private double nError;
  protected double fadingFactor = 0.99;
  private double learningRatio;
  protected double learningRateDecay = 0.001;

  // The Perceptron weights; last slot is the intercept.
  protected double[] weightAttribute;

  // Statistics used for error calculations
  public DoubleVector perceptronattributeStatistics = new DoubleVector();
  public DoubleVector squaredperceptronattributeStatistics = new DoubleVector();

  // The number of instances contributing to this model
  protected int perceptronInstancesSeen;
  protected int perceptronYSeen;

  protected double accumulatedError;

  // If the model (weights) should be reset or not
  protected boolean initialisePerceptron;

  protected double perceptronsumY;
  protected double squaredperceptronsumY;

  public Perceptron() {
    this.initialisePerceptron = true;
  }

  /** Copy constructor without the accumulated error. */
  public Perceptron(Perceptron p) {
    this(p, false);
  }

  /**
   * Copy constructor.
   *
   * @param p                    source perceptron
   * @param copyAccumulatedError whether to also copy the accumulated error
   */
  public Perceptron(Perceptron p, boolean copyAccumulatedError) {
    super();
    this.constantLearningRatioDecay = p.constantLearningRatioDecay;
    this.originalLearningRatio = p.originalLearningRatio;
    if (copyAccumulatedError)
      this.accumulatedError = p.accumulatedError;
    this.nError = p.nError;
    this.fadingFactor = p.fadingFactor;
    this.learningRatio = p.learningRatio;
    this.learningRateDecay = p.learningRateDecay;
    if (p.weightAttribute != null)
      this.weightAttribute = p.weightAttribute.clone();

    this.perceptronattributeStatistics = new DoubleVector(p.perceptronattributeStatistics);
    this.squaredperceptronattributeStatistics = new DoubleVector(p.squaredperceptronattributeStatistics);
    this.perceptronInstancesSeen = p.perceptronInstancesSeen;

    this.initialisePerceptron = p.initialisePerceptron;
    this.perceptronsumY = p.perceptronsumY;
    this.squaredperceptronsumY = p.squaredperceptronsumY;
    this.perceptronYSeen = p.perceptronYSeen;
  }

  /** Reconstructs a perceptron from its serialized form. */
  public Perceptron(PerceptronData p) {
    super();
    this.constantLearningRatioDecay = p.constantLearningRatioDecay;
    this.originalLearningRatio = p.originalLearningRatio;
    this.nError = p.nError;
    this.fadingFactor = p.fadingFactor;
    this.learningRatio = p.learningRatio;
    this.learningRateDecay = p.learningRateDecay;
    if (p.weightAttribute != null)
      this.weightAttribute = p.weightAttribute.clone();

    this.perceptronattributeStatistics = new DoubleVector(p.perceptronattributeStatistics);
    this.squaredperceptronattributeStatistics = new DoubleVector(p.squaredperceptronattributeStatistics);
    this.perceptronInstancesSeen = p.perceptronInstancesSeen;

    this.initialisePerceptron = p.initialisePerceptron;
    this.perceptronsumY = p.perceptronsumY;
    this.squaredperceptronsumY = p.squaredperceptronsumY;
    this.perceptronYSeen = p.perceptronYSeen;
    this.accumulatedError = p.accumulatedError;
  }

  /*
   * Weights
   */
  public void setWeights(double[] w) {
    this.weightAttribute = w;
  }

  public double[] getWeights() {
    return this.weightAttribute;
  }

  /*
   * No. of instances seen
   */
  public int getInstancesSeen() {
    return perceptronInstancesSeen;
  }

  public void setInstancesSeen(int pInstancesSeen) {
    this.perceptronInstancesSeen = pInstancesSeen;
  }

  /**
   * A method to reset the model.
   */
  @Override
  public void resetLearningImpl() {
    this.initialisePerceptron = true;
    this.reset();
  }

  /** Clears all running statistics and errors (weights are rebuilt lazily). */
  public void reset() {
    this.nError = 0.0;
    this.accumulatedError = 0.0;
    this.perceptronInstancesSeen = 0;
    this.perceptronattributeStatistics = new DoubleVector();
    this.squaredperceptronattributeStatistics = new DoubleVector();
    this.perceptronsumY = 0.0;
    this.squaredperceptronsumY = 0.0;
    this.perceptronYSeen = 0;
  }

  public void resetError() {
    this.nError = 0.0;
    this.accumulatedError = 0.0;
  }

  /**
   * Update the model using the provided instance.
   */
  @Override
  public void trainOnInstanceImpl(Instance inst) {
    // Fading error is measured with the pre-update model.
    accumulatedError = Math.abs(this.prediction(inst) - inst.classValue()) + fadingFactor * accumulatedError;
    nError = 1 + fadingFactor * nError;
    // Initialise Perceptron if necessary
    if (this.initialisePerceptron) {
      this.classifierRandom.setSeed(randomSeed);
      this.initialisePerceptron = false; // not in resetLearningImpl() because it needs an Instance!
      this.weightAttribute = new double[inst.numAttributes()];
      for (int j = 0; j < inst.numAttributes(); j++) {
        weightAttribute[j] = 2 * this.classifierRandom.nextDouble() - 1;
      }
      // Update Learning Rate
      learningRatio = originalLearningRatio;
    }

    // Update attribute statistics
    this.perceptronInstancesSeen++;
    this.perceptronYSeen++;

    // NOTE(review): this loop indexes with j directly, whereas
    // normalizedInstance() reads values via modelAttIndexToInstanceAttIndex(j).
    // If that mapping is ever non-identity the statistics would be attributed
    // to the wrong attribute — confirm against the mapping's definition.
    for (int j = 0; j < inst.numAttributes() - 1; j++) {
      perceptronattributeStatistics.addToValue(j, inst.value(j));
      squaredperceptronattributeStatistics.addToValue(j, inst.value(j) * inst.value(j));
    }
    this.perceptronsumY += inst.classValue();
    this.squaredperceptronsumY += inst.classValue() * inst.classValue();

    if (!constantLearningRatioDecay) {
      learningRatio = originalLearningRatio / (1 + perceptronInstancesSeen * learningRateDecay);
    }

    this.updateWeights(inst, learningRatio);
  }

  /**
   * Output the prediction made by this perceptron on the given instance
   * (denormalized back to the original target scale).
   */
  private double prediction(Instance inst) {
    double[] normalizedInstance = normalizedInstance(inst);
    double normalizedPrediction = prediction(normalizedInstance);
    return denormalizedPrediction(normalizedPrediction);
  }

  public double normalizedPrediction(Instance inst) {
    double[] normalizedInstance = normalizedInstance(inst);
    return prediction(normalizedInstance);
  }

  private double denormalizedPrediction(double normalizedPrediction) {
    if (!this.initialisePerceptron) {
      double meanY = perceptronsumY / perceptronYSeen;
      double sdY = computeSD(squaredperceptronsumY, perceptronsumY, perceptronYSeen);
      if (sdY > SD_THRESHOLD)
        return normalizedPrediction * sdY + meanY;
      else
        return normalizedPrediction + meanY;
    } else
      // Perceptron may have been "reset". Use old weights to predict.
      return normalizedPrediction;
  }

  /**
   * Linear model output on an already-normalized instance; the last weight is
   * the intercept. Returns 0 while the model is uninitialized.
   */
  public double prediction(double[] instanceValues) {
    double prediction = 0.0;
    if (!this.initialisePerceptron) {
      for (int j = 0; j < instanceValues.length - 1; j++) {
        prediction += this.weightAttribute[j] * instanceValues[j];
      }
      prediction += this.weightAttribute[instanceValues.length - 1];
    }
    return prediction;
  }

  /** Standardizes the instance's attribute values with the running statistics. */
  public double[] normalizedInstance(Instance inst) {
    // Normalize Instance
    double[] normalizedInstance = new double[inst.numAttributes()];
    for (int j = 0; j < inst.numAttributes() - 1; j++) {
      int instAttIndex = modelAttIndexToInstanceAttIndex(j);
      double mean = perceptronattributeStatistics.getValue(j) / perceptronYSeen;
      double sd = computeSD(squaredperceptronattributeStatistics.getValue(j),
          perceptronattributeStatistics.getValue(j), perceptronYSeen);
      if (sd > SD_THRESHOLD)
        normalizedInstance[j] = (inst.value(instAttIndex) - mean) / sd;
      else
        normalizedInstance[j] = inst.value(instAttIndex) - mean;
    }
    return normalizedInstance;
  }

  /**
   * Sample standard deviation from a sum of squares, a sum, and a count.
   * Returns 0 for fewer than two samples.
   */
  public double computeSD(double squaredVal, double val, int size) {
    if (size > 1) {
      return Math.sqrt((squaredVal - ((val * val) / size)) / (size - 1.0));
    }
    return 0.0;
  }

  /**
   * One delta-rule step in normalized space; weights are rescaled (Lasso-style)
   * when their L1 norm exceeds the number of attributes.
   *
   * @return the denormalized prediction made before the update
   */
  public double updateWeights(Instance inst, double learningRatio) {
    // Normalize Instance
    double[] normalizedInstance = normalizedInstance(inst);
    // Compute the Normalized Prediction of Perceptron
    double normalizedPredict = prediction(normalizedInstance);
    double normalizedY = normalizeActualClassValue(inst);
    double sumWeights = 0.0;
    double delta = normalizedY - normalizedPredict;

    for (int j = 0; j < inst.numAttributes() - 1; j++) {
      int instAttIndex = modelAttIndexToInstanceAttIndex(j);
      if (inst.attribute(instAttIndex).isNumeric()) {
        this.weightAttribute[j] += learningRatio * delta * normalizedInstance[j];
        sumWeights += Math.abs(this.weightAttribute[j]);
      }
    }
    this.weightAttribute[inst.numAttributes() - 1] += learningRatio * delta;
    sumWeights += Math.abs(this.weightAttribute[inst.numAttributes() - 1]);
    if (sumWeights > inst.numAttributes()) { // Lasso regression
      for (int j = 0; j < inst.numAttributes() - 1; j++) {
        int instAttIndex = modelAttIndexToInstanceAttIndex(j);
        if (inst.attribute(instAttIndex).isNumeric()) {
          this.weightAttribute[j] = this.weightAttribute[j] / sumWeights;
        }
      }
      this.weightAttribute[inst.numAttributes() - 1] = this.weightAttribute[inst.numAttributes() - 1] / sumWeights;
    }

    return denormalizedPrediction(normalizedPredict);
  }

  /** Rescales all weights so their absolute values sum to 1. */
  public void normalizeWeights() {
    double sumWeights = 0.0;

    for (double aWeightAttribute : this.weightAttribute) {
      sumWeights += Math.abs(aWeightAttribute);
    }
    for (int j = 0; j < this.weightAttribute.length; j++) {
      this.weightAttribute[j] = this.weightAttribute[j] / sumWeights;
    }
  }

  private double normalizeActualClassValue(Instance inst) {
    double meanY = perceptronsumY / perceptronYSeen;
    double sdY = computeSD(squaredperceptronsumY, perceptronsumY, perceptronYSeen);

    double normalizedY;
    if (sdY > SD_THRESHOLD) {
      normalizedY = (inst.classValue() - meanY) / sdY;
    } else {
      normalizedY = inst.classValue() - meanY;
    }
    return normalizedY;
  }

  @Override
  public boolean isRandomizable() {
    return true;
  }

  @Override
  public double[] getVotesForInstance(Instance inst) {
    return new double[] { this.prediction(inst) };
  }

  @Override
  protected Measurement[] getModelMeasurementsImpl() {
    return null;
  }

  @Override
  public void getModelDescription(StringBuilder out, int indent) {
    if (this.weightAttribute != null) {
      for (int i = 0; i < this.weightAttribute.length - 1; ++i) {
        if (this.weightAttribute[i] >= 0 && i > 0)
          out.append(" +" + Math.round(this.weightAttribute[i] * 1000) / 1000.0 + " X" + i);
        else
          out.append(" " + Math.round(this.weightAttribute[i] * 1000) / 1000.0 + " X" + i);
      }
      if (this.weightAttribute[this.weightAttribute.length - 1] >= 0)
        out.append(" +" + Math.round(this.weightAttribute[this.weightAttribute.length - 1] * 1000) / 1000.0);
      else
        out.append(" " + Math.round(this.weightAttribute[this.weightAttribute.length - 1] * 1000) / 1000.0);
    }
  }

  public void setLearningRatio(double learningRatio) {
    this.learningRatio = learningRatio;
  }

  /** Fading average of the absolute error; MAX_VALUE before any training. */
  public double getCurrentError() {
    if (nError > 0)
      return accumulatedError / nError;
    else
      return Double.MAX_VALUE;
  }

  /** Plain serializable snapshot of a Perceptron's state, used by Kryo. */
  public static class PerceptronData implements Serializable {

    private static final long serialVersionUID = 6727623208744105082L;

    private boolean constantLearningRatioDecay;
    // If the model (weights) should be reset or not
    private boolean initialisePerceptron;

    private double nError;
    private double fadingFactor;
    private double originalLearningRatio;
    private double learningRatio;
    private double learningRateDecay;
    private double accumulatedError;
    private double perceptronsumY;
    private double squaredperceptronsumY;

    // The Perceptron weights
    private double[] weightAttribute;

    // Statistics used for error calculations
    private DoubleVector perceptronattributeStatistics;
    private DoubleVector squaredperceptronattributeStatistics;

    // The number of instances contributing to this model
    private int perceptronInstancesSeen;
    private int perceptronYSeen;

    public PerceptronData() {
    }

    public PerceptronData(Perceptron p) {
      this.constantLearningRatioDecay = p.constantLearningRatioDecay;
      this.initialisePerceptron = p.initialisePerceptron;
      this.nError = p.nError;
      this.fadingFactor = p.fadingFactor;
      this.originalLearningRatio = p.originalLearningRatio;
      this.learningRatio = p.learningRatio;
      this.learningRateDecay = p.learningRateDecay;
      this.accumulatedError = p.accumulatedError;
      this.perceptronsumY = p.perceptronsumY;
      this.squaredperceptronsumY = p.squaredperceptronsumY;
      this.weightAttribute = p.weightAttribute;
      this.perceptronattributeStatistics = p.perceptronattributeStatistics;
      this.squaredperceptronattributeStatistics = p.squaredperceptronattributeStatistics;
      this.perceptronInstancesSeen = p.perceptronInstancesSeen;
      this.perceptronYSeen = p.perceptronYSeen;
    }

    public Perceptron build() {
      return new Perceptron(this);
    }
  }

  /** Kryo serializer that round-trips a Perceptron through PerceptronData. */
  public static final class PerceptronSerializer extends Serializer<Perceptron> {

    @Override
    public void write(Kryo kryo, Output output, Perceptron p) {
      kryo.writeObjectOrNull(output, new PerceptronData(p), PerceptronData.class);
    }

    @Override
    public Perceptron read(Kryo kryo, Input input, Class<Perceptron> type) {
      PerceptronData perceptronData = kryo.readObjectOrNull(input, PerceptronData.class);
      // FIX: readObjectOrNull may legitimately return null (a null was
      // written); the old code dereferenced it unconditionally and NPE'd.
      return perceptronData == null ? null : perceptronData.build();
    }
  }
}
