http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/787864b6/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/common/Rule.java ---------------------------------------------------------------------- diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/common/Rule.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/common/Rule.java new file mode 100644 index 0000000..b85cf10 --- /dev/null +++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/common/Rule.java @@ -0,0 +1,111 @@ +package com.yahoo.labs.samoa.learners.classifiers.rules.common; + +/* + * #%L + * SAMOA + * %% + * Copyright (C) 2013 - 2014 Yahoo! Inc. + * %% + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * #L% + */ + +import java.util.LinkedList; +import java.util.List; + +import com.yahoo.labs.samoa.instances.Instance; +import com.yahoo.labs.samoa.moa.AbstractMOAObject; +import com.yahoo.labs.samoa.moa.classifiers.rules.core.conditionaltests.NumericAttributeBinaryRulePredicate; + +/** + * The base class for "rule". + * Represents the most basic rule with and ID and a list of features (nodeList). 
+ * + * @author Anh Thu Vu + * + */ +public abstract class Rule extends AbstractMOAObject { + private static final long serialVersionUID = 1L; + + protected int ruleNumberID; + + protected List<RuleSplitNode> nodeList; + + /* + * Constructor + */ + public Rule() { + this.nodeList = new LinkedList<RuleSplitNode>(); + } + + /* + * Rule ID + */ + public int getRuleNumberID() { + return ruleNumberID; + } + + public void setRuleNumberID(int ruleNumberID) { + this.ruleNumberID = ruleNumberID; + } + + /* + * RuleSplitNode list + */ + public List<RuleSplitNode> getNodeList() { + return nodeList; + } + + public void setNodeList(List<RuleSplitNode> nodeList) { + this.nodeList = nodeList; + } + + /* + * Covering + */ + public boolean isCovering(Instance inst) { + boolean isCovering = true; + for (RuleSplitNode node : nodeList) { + if (node.evaluate(inst) == false) { + isCovering = false; + break; + } + } + return isCovering; + } + + /* + * Add RuleSplitNode + */ + public boolean nodeListAdd(RuleSplitNode ruleSplitNode) { + //Check that the node is not already in the list + boolean isIncludedInNodeList = false; + boolean isUpdated=false; + for (RuleSplitNode node : nodeList) { + NumericAttributeBinaryRulePredicate nodeTest = (NumericAttributeBinaryRulePredicate) node.getSplitTest(); + NumericAttributeBinaryRulePredicate ruleSplitNodeTest = (NumericAttributeBinaryRulePredicate) ruleSplitNode.getSplitTest(); + if (nodeTest.isUsingSameAttribute(ruleSplitNodeTest)) { + isIncludedInNodeList = true; + if (nodeTest.isIncludedInRuleNode(ruleSplitNodeTest) == true) { //remove this line to keep the most recent attribute value + //replace the value + nodeTest.setAttributeValue(ruleSplitNodeTest); + isUpdated=true; //if is updated (i.e. an expansion happened) a new learning node should be created + } + } + } + if (isIncludedInNodeList == false) { + this.nodeList.add(ruleSplitNode); + } + return (!isIncludedInNodeList || isUpdated); + } +}
http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/787864b6/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/common/RuleActiveLearningNode.java ---------------------------------------------------------------------- diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/common/RuleActiveLearningNode.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/common/RuleActiveLearningNode.java new file mode 100644 index 0000000..f52ac32 --- /dev/null +++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/common/RuleActiveLearningNode.java @@ -0,0 +1,34 @@ +package com.yahoo.labs.samoa.learners.classifiers.rules.common; + +/* + * #%L + * SAMOA + * %% + * Copyright (C) 2013 - 2014 Yahoo! Inc. + * %% + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * #L% + */ + +/** + * Interface for Rule's LearningNode that updates both statistics + * for expanding rule and computing predictions. 
+ * + * @author Anh Thu Vu + * + */ +public interface RuleActiveLearningNode extends RulePassiveLearningNode { + + public boolean tryToExpand(double splitConfidence, double tieThreshold); + +} http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/787864b6/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/common/RuleActiveRegressionNode.java ---------------------------------------------------------------------- diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/common/RuleActiveRegressionNode.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/common/RuleActiveRegressionNode.java new file mode 100644 index 0000000..05079ed --- /dev/null +++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/common/RuleActiveRegressionNode.java @@ -0,0 +1,318 @@ +package com.yahoo.labs.samoa.learners.classifiers.rules.common; + +/* + * #%L + * SAMOA + * %% + * Copyright (C) 2013 - 2014 Yahoo! Inc. + * %% + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * #L% + */ + +import java.util.Arrays; +import java.util.LinkedList; +import java.util.List; + +import com.yahoo.labs.samoa.instances.Instance; +import com.yahoo.labs.samoa.moa.core.AutoExpandVector; +import com.yahoo.labs.samoa.moa.core.DoubleVector; +import com.yahoo.labs.samoa.moa.classifiers.core.AttributeSplitSuggestion; +import com.yahoo.labs.samoa.moa.classifiers.core.attributeclassobservers.AttributeClassObserver; +import com.yahoo.labs.samoa.moa.classifiers.core.attributeclassobservers.FIMTDDNumericAttributeClassObserver; +import com.yahoo.labs.samoa.moa.classifiers.core.splitcriteria.SplitCriterion; +import com.yahoo.labs.samoa.moa.classifiers.rules.core.attributeclassobservers.FIMTDDNumericAttributeClassLimitObserver; +import com.yahoo.labs.samoa.moa.classifiers.rules.core.splitcriteria.SDRSplitCriterionAMRules; +import com.yahoo.labs.samoa.moa.classifiers.rules.driftdetection.PageHinkleyFading; +import com.yahoo.labs.samoa.moa.classifiers.rules.driftdetection.PageHinkleyTest; + +/** + * LearningNode for regression rule that updates both statistics for + * expanding rule and computing predictions. 
+ * + * @author Anh Thu Vu + * + */ +public class RuleActiveRegressionNode extends RuleRegressionNode implements RuleActiveLearningNode { + + /** + * + */ + private static final long serialVersionUID = 519854943188168546L; + + protected int splitIndex = 0; + + protected PageHinkleyTest pageHinckleyTest; + protected boolean changeDetection; + + protected double[] statisticsNewRuleActiveLearningNode = null; + protected double[] statisticsBranchSplit = null; + protected double[] statisticsOtherBranchSplit; + + protected AttributeSplitSuggestion bestSuggestion = null; + + protected AutoExpandVector<AttributeClassObserver> attributeObservers = new AutoExpandVector<>(); + private FIMTDDNumericAttributeClassLimitObserver numericObserver; + + /* + * Simple setters & getters + */ + public int getSplitIndex() { + return splitIndex; + } + + public void setSplitIndex(int splitIndex) { + this.splitIndex = splitIndex; + } + + public double[] getStatisticsOtherBranchSplit() { + return statisticsOtherBranchSplit; + } + + public void setStatisticsOtherBranchSplit(double[] statisticsOtherBranchSplit) { + this.statisticsOtherBranchSplit = statisticsOtherBranchSplit; + } + + public double[] getStatisticsBranchSplit() { + return statisticsBranchSplit; + } + + public void setStatisticsBranchSplit(double[] statisticsBranchSplit) { + this.statisticsBranchSplit = statisticsBranchSplit; + } + + public double[] getStatisticsNewRuleActiveLearningNode() { + return statisticsNewRuleActiveLearningNode; + } + + public void setStatisticsNewRuleActiveLearningNode( + double[] statisticsNewRuleActiveLearningNode) { + this.statisticsNewRuleActiveLearningNode = statisticsNewRuleActiveLearningNode; + } + + public AttributeSplitSuggestion getBestSuggestion() { + return bestSuggestion; + } + + public void setBestSuggestion(AttributeSplitSuggestion bestSuggestion) { + this.bestSuggestion = bestSuggestion; + } + + /* + * Constructor with builder + */ + public RuleActiveRegressionNode() { + super(); + } + 
public RuleActiveRegressionNode(ActiveRule.Builder builder) { + super(builder.statistics); + this.changeDetection = builder.changeDetection; + if (!builder.changeDetection) { + this.pageHinckleyTest = new PageHinkleyFading(builder.threshold, builder.alpha); + } + this.predictionFunction = builder.predictionFunction; + this.learningRatio = builder.learningRatio; + this.ruleNumberID = builder.id; + this.numericObserver = builder.numericObserver; + + this.perceptron = new Perceptron(); + this.perceptron.prepareForUse(); + this.perceptron.originalLearningRatio = builder.learningRatio; + this.perceptron.constantLearningRatioDecay = builder.constantLearningRatioDecay; + + + if(this.predictionFunction!=1) + { + this.targetMean = new TargetMean(); + if (builder.statistics[0]>0) + this.targetMean.reset(builder.statistics[1]/builder.statistics[0],(long)builder.statistics[0]); + } + this.predictionFunction = builder.predictionFunction; + if (builder.statistics!=null) + this.nodeStatistics=new DoubleVector(builder.statistics); + } + + /* + * Update with input instance + */ + public boolean updatePageHinckleyTest(double error) { + boolean changeDetected = false; + if (!this.changeDetection) { + changeDetected = pageHinckleyTest.update(error); + } + return changeDetected; + } + + public boolean updateChangeDetection(double error) { + return !changeDetection && pageHinckleyTest.update(error); + } + + @Override + public void updateStatistics(Instance inst) { + // Update the statistics for this node + // number of instances passing through the node + nodeStatistics.addToValue(0, 1); + // sum of y values + nodeStatistics.addToValue(1, inst.classValue()); + // sum of squared y values + nodeStatistics.addToValue(2, inst.classValue()*inst.classValue()); + + for (int i = 0; i < inst.numAttributes() - 1; i++) { + int instAttIndex = modelAttIndexToInstanceAttIndex(i, inst); + + AttributeClassObserver obs = this.attributeObservers.get(i); + if (obs == null) { + // At this stage all nominal 
attributes are ignored + if (inst.attribute(instAttIndex).isNumeric()) //instAttIndex + { + obs = newNumericClassObserver(); + this.attributeObservers.set(i, obs); + } + } + if (obs != null) { + ((FIMTDDNumericAttributeClassObserver) obs).observeAttributeClass(inst.value(instAttIndex), inst.classValue(), inst.weight()); + } + } + + this.perceptron.trainOnInstance(inst); + if (this.predictionFunction != 1) { //Train target mean if prediction function is not Perceptron + this.targetMean.trainOnInstance(inst); + } + } + + protected AttributeClassObserver newNumericClassObserver() { + //return new FIMTDDNumericAttributeClassObserver(); + //return new FIMTDDNumericAttributeClassLimitObserver(); + //return (AttributeClassObserver)((AttributeClassObserver)this.numericObserverOption.getPreMaterializedObject()).copy(); + FIMTDDNumericAttributeClassLimitObserver newObserver = new FIMTDDNumericAttributeClassLimitObserver(); + newObserver.setMaxNodes(numericObserver.getMaxNodes()); + return newObserver; + } + + /* + * Init after being split from oldLearningNode + */ + public void initialize(RuleRegressionNode oldLearningNode) { + if(oldLearningNode.perceptron!=null) + { + this.perceptron=new Perceptron(oldLearningNode.perceptron); + this.perceptron.resetError(); + this.perceptron.setLearningRatio(oldLearningNode.learningRatio); + } + + if(oldLearningNode.targetMean!=null) + { + this.targetMean= new TargetMean(oldLearningNode.targetMean); + this.targetMean.resetError(); + } + //reset statistics + this.nodeStatistics.setValue(0, 0); + this.nodeStatistics.setValue(1, 0); + this.nodeStatistics.setValue(2, 0); + } + + /* + * Expand + */ + @Override + public boolean tryToExpand(double splitConfidence, double tieThreshold) { + + // splitConfidence. Hoeffding Bound test parameter. + // tieThreshold. Hoeffding Bound test parameter. 
+ SplitCriterion splitCriterion = new SDRSplitCriterionAMRules(); + //SplitCriterion splitCriterion = new SDRSplitCriterionAMRulesNode();//JD for assessing only best branch + + // Using this criterion, find the best split per attribute and rank the results + AttributeSplitSuggestion[] bestSplitSuggestions = this.getBestSplitSuggestions(splitCriterion); + Arrays.sort(bestSplitSuggestions); + // Declare a variable to determine if any of the splits should be performed + boolean shouldSplit = false; + + // If only one split was returned, use it + if (bestSplitSuggestions.length < 2) { + shouldSplit = ((bestSplitSuggestions.length > 0) && (bestSplitSuggestions[0].merit > 0)); + bestSuggestion = bestSplitSuggestions[bestSplitSuggestions.length - 1]; + } // Otherwise, consider which of the splits proposed may be worth trying + else { + // Determine the hoeffding bound value, used to select how many instances should be used to make a test decision + // to feel reasonably confident that the test chosen by this sample is the same as what would be chosen using infinite examples + double hoeffdingBound = computeHoeffdingBound(1, splitConfidence, getInstancesSeen()); + // Determine the top two ranked splitting suggestions + bestSuggestion = bestSplitSuggestions[bestSplitSuggestions.length - 1]; + AttributeSplitSuggestion secondBestSuggestion = bestSplitSuggestions[bestSplitSuggestions.length - 2]; + + // If the upper bound of the sample mean for the ratio of SDR(best suggestion) to SDR(second best suggestion), + // as determined using the hoeffding bound, is less than 1, then the true mean is also less than 1, and thus at this + // particular moment of observation the bestSuggestion is indeed the best split option with confidence 1-delta, and + // splitting should occur. 
+ // Alternatively, if two or more splits are very similar or identical in terms of their splits, then a threshold limit + // (default 0.05) is applied to the hoeffding bound; if the hoeffding bound is smaller than this limit then the two + // competing attributes are equally good, and the split will be made on the one with the higher SDR value. + + if (bestSuggestion.merit > 0) { + if ((((secondBestSuggestion.merit / bestSuggestion.merit) + hoeffdingBound) < 1) + || (hoeffdingBound < tieThreshold)) { + shouldSplit = true; + } + } + } + + if (shouldSplit) { + AttributeSplitSuggestion splitDecision = bestSplitSuggestions[bestSplitSuggestions.length - 1]; + double minValue = Double.MAX_VALUE; + double[] branchMerits = SDRSplitCriterionAMRules.computeBranchSplitMerits(bestSuggestion.resultingClassDistributions); + + for (int i = 0; i < bestSuggestion.numSplits(); i++) { + double value = branchMerits[i]; + if (value < minValue) { + minValue = value; + splitIndex = i; + statisticsNewRuleActiveLearningNode = bestSuggestion.resultingClassDistributionFromSplit(i); + } + } + statisticsBranchSplit = splitDecision.resultingClassDistributionFromSplit(splitIndex); + statisticsOtherBranchSplit = bestSuggestion.resultingClassDistributionFromSplit(splitIndex == 0 ? 
1 : 0); + + } + return shouldSplit; + } + + public AutoExpandVector<AttributeClassObserver> getAttributeObservers() { + return this.attributeObservers; + } + + public AttributeSplitSuggestion[] getBestSplitSuggestions(SplitCriterion criterion) { + + List<AttributeSplitSuggestion> bestSuggestions = new LinkedList<AttributeSplitSuggestion>(); + + // Set the nodeStatistics up as the preSplitDistribution, rather than the observedClassDistribution + double[] nodeSplitDist = this.nodeStatistics.getArrayCopy(); + for (int i = 0; i < this.attributeObservers.size(); i++) { + AttributeClassObserver obs = this.attributeObservers.get(i); + if (obs != null) { + + // AT THIS STAGE NON-NUMERIC ATTRIBUTES ARE IGNORED + AttributeSplitSuggestion bestSuggestion = null; + if (obs instanceof FIMTDDNumericAttributeClassObserver) { + bestSuggestion = obs.getBestEvaluatedSplitSuggestion(criterion, nodeSplitDist, i, true); + } + + if (bestSuggestion != null) { + bestSuggestions.add(bestSuggestion); + } + } + } + return bestSuggestions.toArray(new AttributeSplitSuggestion[bestSuggestions.size()]); + } + +} http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/787864b6/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/common/RulePassiveLearningNode.java ---------------------------------------------------------------------- diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/common/RulePassiveLearningNode.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/common/RulePassiveLearningNode.java new file mode 100644 index 0000000..4934225 --- /dev/null +++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/common/RulePassiveLearningNode.java @@ -0,0 +1,33 @@ +package com.yahoo.labs.samoa.learners.classifiers.rules.common; + +/* + * #%L + * SAMOA + * %% + * Copyright (C) 2013 - 2014 Yahoo! Inc. 
+ * %% + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * #L% + */ + +/** + * Interface for Rule's LearningNode that does not update + * statistics for expanding rule. It only updates statistics for + * computing predictions. + * + * @author Anh Thu Vu + * + */ +public interface RulePassiveLearningNode { + +} http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/787864b6/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/common/RulePassiveRegressionNode.java ---------------------------------------------------------------------- diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/common/RulePassiveRegressionNode.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/common/RulePassiveRegressionNode.java new file mode 100644 index 0000000..674e482 --- /dev/null +++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/common/RulePassiveRegressionNode.java @@ -0,0 +1,76 @@ +package com.yahoo.labs.samoa.learners.classifiers.rules.common; + +/* + * #%L + * SAMOA + * %% + * Copyright (C) 2013 - 2014 Yahoo! Inc. + * %% + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * #L% + */ + +import com.yahoo.labs.samoa.instances.Instance; +import com.yahoo.labs.samoa.moa.core.DoubleVector; + +/** + * LearningNode for regression rule that does not update + * statistics for expanding rule. It only updates statistics for + * computing predictions. + * + * @author Anh Thu Vu + * + */ +public class RulePassiveRegressionNode extends RuleRegressionNode implements RulePassiveLearningNode { + + /** + * + */ + private static final long serialVersionUID = 3720878438856489690L; + + public RulePassiveRegressionNode (double[] statistics) { + super(statistics); + } + + public RulePassiveRegressionNode() { + super(); + } + + public RulePassiveRegressionNode(RuleRegressionNode activeLearningNode) { + this.predictionFunction = activeLearningNode.predictionFunction; + this.ruleNumberID = activeLearningNode.ruleNumberID; + this.nodeStatistics = new DoubleVector(activeLearningNode.nodeStatistics); + this.learningRatio = activeLearningNode.learningRatio; + this.perceptron = new Perceptron(activeLearningNode.perceptron, true); + this.targetMean = new TargetMean(activeLearningNode.targetMean); + } + + /* + * Update with input instance + */ + @Override + public void updateStatistics(Instance inst) { + // Update the statistics for this node + // number of instances passing through the node + nodeStatistics.addToValue(0, 1); + // sum of y values + nodeStatistics.addToValue(1, inst.classValue()); + // sum of squared y values + nodeStatistics.addToValue(2, inst.classValue()*inst.classValue()); + + this.perceptron.trainOnInstance(inst); + if 
(this.predictionFunction != 1) { //Train target mean if prediction function is not Perceptron + this.targetMean.trainOnInstance(inst); + } + } +} http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/787864b6/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/common/RuleRegressionNode.java ---------------------------------------------------------------------- diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/common/RuleRegressionNode.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/common/RuleRegressionNode.java new file mode 100644 index 0000000..45f5719 --- /dev/null +++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/common/RuleRegressionNode.java @@ -0,0 +1,292 @@ +package com.yahoo.labs.samoa.learners.classifiers.rules.common; + +/* + * #%L + * SAMOA + * %% + * Copyright (C) 2013 - 2014 Yahoo! Inc. + * %% + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * #L% + */ + +import java.io.Serializable; + +import com.yahoo.labs.samoa.instances.Instance; +import com.yahoo.labs.samoa.moa.core.DoubleVector; + +/** + * The base class for LearningNode for regression rule. 
+ * + * @author Anh Thu Vu + * + */ +public abstract class RuleRegressionNode implements Serializable { + + private static final long serialVersionUID = 9129659494380381126L; + + protected int predictionFunction; + protected int ruleNumberID; + // The statistics for this node: + // Number of instances that have reached it + // Sum of y values + // Sum of squared y values + protected DoubleVector nodeStatistics; + + protected Perceptron perceptron; + protected TargetMean targetMean; + protected double learningRatio; + + /* + * Simple setters & getters + */ + public Perceptron getPerceptron() { + return perceptron; + } + + public void setPerceptron(Perceptron perceptron) { + this.perceptron = perceptron; + } + + public TargetMean getTargetMean() { + return targetMean; + } + + public void setTargetMean(TargetMean targetMean) { + this.targetMean = targetMean; + } + + /* + * Create a new RuleRegressionNode + */ + public RuleRegressionNode(double[] initialClassObservations) { + this.nodeStatistics = new DoubleVector(initialClassObservations); + } + + public RuleRegressionNode() { + this(new double[0]); + } + + /* + * Update statistics with input instance + */ + public abstract void updateStatistics(Instance instance); + + /* + * Predictions + */ + public double[] getPrediction(Instance instance) { + int predictionMode = this.getLearnerToUse(this.predictionFunction); + return getPrediction(instance, predictionMode); + } + + public double[] getSimplePrediction() { + if( this.targetMean!=null) + return this.targetMean.getVotesForInstance(); + else + return new double[]{0}; + } + + public double[] getPrediction(Instance instance, int predictionMode) { + double[] ret; + if (predictionMode == 1) + ret=this.perceptron.getVotesForInstance(instance); + else + ret=this.targetMean.getVotesForInstance(instance); + return ret; + } + + public double getNormalizedPrediction(Instance instance) { + double res; + double [] aux; + switch (this.predictionFunction) { + //perceptron - 1 + 
case 1: + res=this.perceptron.normalizedPrediction(instance); + break; + //target mean - 2 + case 2: + aux=this.targetMean.getVotesForInstance(); + res=normalize(aux[0]); + break; + //adaptive - 0 + case 0: + int predictionMode = this.getLearnerToUse(0); + if(predictionMode == 1) + { + res=this.perceptron.normalizedPrediction(instance); + } + else{ + aux=this.targetMean.getVotesForInstance(instance); + res = normalize(aux[0]); + } + break; + default: + throw new UnsupportedOperationException("Prediction mode not in range."); + } + return res; + } + + /* + * Get learner mode + */ + public int getLearnerToUse(int predMode) { + int predictionMode = predMode; + if (predictionMode == 0) { + double perceptronError= this.perceptron.getCurrentError(); + double meanTargetError =this.targetMean.getCurrentError(); + if (perceptronError < meanTargetError) + predictionMode = 1; //PERCEPTRON + else + predictionMode = 2; //TARGET MEAN + } + return predictionMode; + } + + /* + * Error and change detection + */ + public double computeError(Instance instance) { + double normalizedPrediction = getNormalizedPrediction(instance); + double normalizedClassValue = normalize(instance.classValue()); + return Math.abs(normalizedClassValue - normalizedPrediction); + } + + public double getCurrentError() { + double error; + if (this.perceptron!=null){ + if (targetMean==null) + error=perceptron.getCurrentError(); + else{ + double errorP=perceptron.getCurrentError(); + double errorTM=targetMean.getCurrentError(); + error = (errorP<errorTM) ? errorP : errorTM; + } + } + else + error=Double.MAX_VALUE; + return error; + } + + /* + * no. 
of instances seen + */ + public long getInstancesSeen() { + if (nodeStatistics != null) { + return (long)this.nodeStatistics.getValue(0); + } else { + return 0; + } + } + + public DoubleVector getNodeStatistics(){ + return this.nodeStatistics; + } + + /* + * Anomaly detection + */ + public boolean isAnomaly(Instance instance, + double uniVariateAnomalyProbabilityThreshold, + double multiVariateAnomalyProbabilityThreshold, + int numberOfInstanceesForAnomaly) { + //AMRUles is equipped with anomaly detection. If on, compute the anomaly value. + long perceptronIntancesSeen=this.perceptron.getInstancesSeen(); + if ( perceptronIntancesSeen>= numberOfInstanceesForAnomaly) { + double attribSum; + double attribSquaredSum; + double D = 0.0; + double N = 0.0; + double anomaly; + + for (int x = 0; x < instance.numAttributes() - 1; x++) { + // Perceptron is initialized each rule. + // this is a local anomaly. + int instAttIndex = modelAttIndexToInstanceAttIndex(x, instance); + attribSum = this.perceptron.perceptronattributeStatistics.getValue(x); + attribSquaredSum = this.perceptron.squaredperceptronattributeStatistics.getValue(x); + double mean = attribSum / perceptronIntancesSeen; + double sd = computeSD(attribSquaredSum, attribSum, perceptronIntancesSeen); + double probability = computeProbability(mean, sd, instance.value(instAttIndex)); + + if (probability > 0.0) { + D = D + Math.abs(Math.log(probability)); + if (probability < uniVariateAnomalyProbabilityThreshold) {//0.10 + N = N + Math.abs(Math.log(probability)); + } + } + } + + anomaly = 0.0; + if (D != 0.0) { + anomaly = N / D; + } + if (anomaly >= multiVariateAnomalyProbabilityThreshold) { + //debuganomaly(instance, + // uniVariateAnomalyProbabilityThreshold, + // multiVariateAnomalyProbabilityThreshold, + // anomaly); + return true; + } + } + return false; + } + + /* + * Helpers + */ + public static double computeProbability(double mean, double sd, double value) { + double probability = 0.0; + + if (sd > 0.0) { + 
            double k = (Math.abs(value - mean) / sd); // One tailed variant of Chebyshev's inequality
            probability= 1.0 / (1+k*k);
        }

        return probability;
    }

    /**
     * Hoeffding bound epsilon = sqrt(range^2 * ln(1/confidence) / (2n)).
     * NOTE(review): assumes n > 0 and 0 < confidence < 1 — callers invoke this
     * only at grace-period boundaries, so n should be positive; confirm.
     */
    public static double computeHoeffdingBound(double range, double confidence, double n) {
        return Math.sqrt(((range * range) * Math.log(1.0 / confidence)) / (2.0 * n));
    }

    /**
     * Z-score style normalization of a target value using this node's running
     * statistics ([count, sum, sumSq]). Returns 0.0 when fewer than two
     * instances have been seen (computeSD yields 0 in that case, so the
     * sdY > 1e-7 guard is not taken) or when the spread is negligible.
     */
    private double normalize(double value) {
        double meanY = this.nodeStatistics.getValue(1)/this.nodeStatistics.getValue(0);
        double sdY = computeSD(this.nodeStatistics.getValue(2), this.nodeStatistics.getValue(1), (long)this.nodeStatistics.getValue(0));
        double normalizedY = 0.0;
        if (sdY > 0.0000001) {
            normalizedY = (value - meanY) / (sdY);
        }
        return normalizedY;
    }

    /**
     * Sample standard deviation from running sums: sqrt((sumSq - sum^2/n)/(n-1)).
     * Returns 0.0 for size <= 1 (undefined sample sd).
     */
    public double computeSD(double squaredVal, double val, long size) {
        if (size > 1) {
            return Math.sqrt((squaredVal - ((val * val) / size)) / (size - 1.0));
        }
        return 0.0;
    }

    /**
     * Gets the index of the attribute in the instance,
     * given the index of the attribute in the learner.
     *
     * @param index the index of the attribute in the learner
     * @param inst the instance
     * @return the index in the instance
     */
    protected static int modelAttIndexToInstanceAttIndex(int index, Instance inst) {
        return index<= inst.classIndex() ?
        index : index + 1;
    }
}
http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/787864b6/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/common/RuleSplitNode.java
----------------------------------------------------------------------
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/common/RuleSplitNode.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/common/RuleSplitNode.java
new file mode 100644
index 0000000..28f4890
--- /dev/null
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/common/RuleSplitNode.java
@@ -0,0 +1,66 @@
package com.yahoo.labs.samoa.learners.classifiers.rules.common;

/*
 * #%L
 * SAMOA
 * %%
 * Copyright (C) 2013 - 2014 Yahoo! Inc.
 * %%
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * #L%
 */

import com.yahoo.labs.samoa.moa.classifiers.core.conditionaltests.InstanceConditionalTest;
import com.yahoo.labs.samoa.moa.classifiers.rules.core.Predicate;
import com.yahoo.labs.samoa.moa.classifiers.rules.core.conditionaltests.NumericAttributeBinaryRulePredicate;
import com.yahoo.labs.samoa.learners.classifiers.trees.SplitNode;
import com.yahoo.labs.samoa.instances.Instance;

/**
 * Represent a feature of rules (an element of rule's nodeList).
 *
 * @author Anh Thu Vu
 *
 */
public class RuleSplitNode extends SplitNode {

    // NOTE(review): neither field is read within this class; presumably kept for
    // subclasses and/or serialization compatibility — confirm before removing.
    protected double lastTargetMean;
    protected int operatorObserver;

    private static final long serialVersionUID = 1L;

    /** The conditional test (predicate) this split node wraps. */
    public InstanceConditionalTest getSplitTest() {
        return this.splitTest;
    }

    /**
     * Create a new RuleSplitNode
     */
    public RuleSplitNode() {
        this(null, new double[0]);
    }
    public RuleSplitNode(InstanceConditionalTest splitTest, double[] classObservations) {
        super(splitTest, classObservations);
    }

    /**
     * Copy of this node: clones the numeric predicate, passing the observed
     * class distribution along.
     * NOTE(review): throws ClassCastException/NullPointerException if splitTest
     * is not a NumericAttributeBinaryRulePredicate (e.g. after the no-arg
     * constructor) — confirm callers only copy properly initialized nodes.
     */
    public RuleSplitNode getACopy() {
        InstanceConditionalTest splitTest = new NumericAttributeBinaryRulePredicate((NumericAttributeBinaryRulePredicate) this.getSplitTest());
        return new RuleSplitNode(splitTest, this.getObservedClassDistribution());
    }

    /** True if the instance satisfies this node's predicate. */
    public boolean evaluate(Instance instance) {
        Predicate predicate = (Predicate) this.splitTest;
        return predicate.evaluate(instance);
    }

}
http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/787864b6/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/common/TargetMean.java
----------------------------------------------------------------------
diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/common/TargetMean.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/common/TargetMean.java
new file mode 100644
index 0000000..902acf0
--- /dev/null
+++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/common/TargetMean.java
@@ -0,0 +1,220 @@
/*
 * TargetMean.java
 * Copyright (C) 2014 University of Porto, Portugal
 * @author J. Duarte, A. Bifet, J. Gama
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * + */ +package com.yahoo.labs.samoa.learners.classifiers.rules.common; + +/* + * #%L + * SAMOA + * %% + * Copyright (C) 2013 - 2014 Yahoo! Inc. + * %% + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * #L% + */ +/** + * Prediction scheme using TargetMean: + * TargetMean - Returns the mean of the target variable of the training instances + * + * @author Joao Duarte + * + * */ + +import com.esotericsoftware.kryo.Kryo; +import com.esotericsoftware.kryo.Serializer; +import com.esotericsoftware.kryo.io.Input; +import com.esotericsoftware.kryo.io.Output; +import com.github.javacliparser.FloatOption; +import com.yahoo.labs.samoa.instances.Instance; +import com.yahoo.labs.samoa.moa.classifiers.AbstractClassifier; +import com.yahoo.labs.samoa.moa.classifiers.Regressor; +import com.yahoo.labs.samoa.moa.core.Measurement; +import com.yahoo.labs.samoa.moa.core.StringUtils; + +public class TargetMean extends AbstractClassifier implements Regressor { + + /** + * + */ + protected long n; + protected double sum; + protected double errorSum; + protected double nError; + private double fadingErrorFactor; + + private static final long serialVersionUID = 7152547322803559115L; + + public FloatOption fadingErrorFactorOption = new FloatOption( + "fadingErrorFactor", 'e', + "Fading error factor for the TargetMean accumulated error", 0.99, 0, 1); + + @Override + public boolean isRandomizable() { + return false; + } + + @Override + public double[] getVotesForInstance(Instance inst) { + return getVotesForInstance(); + } + + public double[] getVotesForInstance() { + double[] currentMean=new double[1]; + if (n>0) + currentMean[0]=sum/n; + else + currentMean[0]=0; + return currentMean; + } + + @Override + public void resetLearningImpl() { + sum=0; + n=0; + errorSum=Double.MAX_VALUE; + nError=0; + } + + @Override + public void trainOnInstanceImpl(Instance inst) { + updateAccumulatedError(inst); + ++this.n; + this.sum+=inst.classValue(); + } + protected void updateAccumulatedError(Instance inst){ + double mean=0; + nError=1+fadingErrorFactor*nError; + if(n>0) + mean=sum/n; + errorSum=Math.abs(inst.classValue()-mean)+fadingErrorFactor*errorSum; + } + + @Override + protected Measurement[] 
    getModelMeasurementsImpl() {
        // No model measurements are reported for this simple learner.
        return null;
    }

    @Override
    public void getModelDescription(StringBuilder out, int indent) {
        // NOTE(review): prints NaN/Infinity before any instance is trained
        // (this.n == 0) — confirm whether callers only invoke this after training.
        StringUtils.appendIndented(out, indent, "Current Mean: " + this.sum/this.n);
        StringUtils.appendNewline(out);

    }
    /* JD
     * Resets the learner but initializes with a starting point
     * */
    public void reset(double currentMean, long numberOfInstances) {
        this.sum=currentMean*numberOfInstances;
        this.n=numberOfInstances;
        this.resetError();
    }

    /* JD
     * Current fading mean absolute error; Double.MAX_VALUE when no error has
     * been accumulated yet (nError == 0).
     * */
    public double getCurrentError(){
        if(this.nError>0)
            return this.errorSum/this.nError;
        else
            return Double.MAX_VALUE;
    }

    /** Copy constructor. */
    public TargetMean(TargetMean t) {
        super();
        this.n = t.n;
        this.sum = t.sum;
        this.errorSum = t.errorSum;
        this.nError = t.nError;
        this.fadingErrorFactor = t.fadingErrorFactor;
        // NOTE(review): shares (aliases) the mutable option object with the source
        // instance rather than copying it — confirm this is intended.
        this.fadingErrorFactorOption = t.fadingErrorFactorOption;
    }

    /** Reconstructs a learner from its serializable snapshot (see TargetMeanData). */
    public TargetMean(TargetMeanData td) {
        this();
        this.n = td.n;
        this.sum = td.sum;
        this.errorSum = td.errorSum;
        this.nError = td.nError;
        this.fadingErrorFactor = td.fadingErrorFactor;
        this.fadingErrorFactorOption.setValue(td.fadingErrorFactorOptionValue);
    }

    public TargetMean() {
        super();
        // Cache the option value; the option object itself is only consulted here.
        fadingErrorFactor=fadingErrorFactorOption.getValue();
    }

    /** Clears the accumulated (fading) error statistics. */
    public void resetError() {
        this.errorSum=0;
        this.nError=0;
    }

    /**
     * Plain serializable snapshot of a TargetMean's state, used by the Kryo
     * serializer below (avoids serializing the option object directly).
     */
    public static class TargetMeanData {
        private long n;
        private double sum;
        private double errorSum;
        private double nError;
        private double fadingErrorFactor;
        private double fadingErrorFactorOptionValue;

        public TargetMeanData() {

        }

        public TargetMeanData(TargetMean tm) {
            this.n = tm.n;
            this.sum = tm.sum;
            this.errorSum = tm.errorSum;
            this.nError = tm.nError;
            this.fadingErrorFactor = tm.fadingErrorFactor;
            if (tm.fadingErrorFactorOption != null)
                this.fadingErrorFactorOptionValue = tm.fadingErrorFactorOption.getValue();
            else
                // Fall back to the option's declared default when unset.
                this.fadingErrorFactorOptionValue = 0.99;
+ } + + public TargetMean build() { + return new TargetMean(this); + } + } + + public static final class TargetMeanSerializer extends Serializer<TargetMean>{ + + @Override + public void write(Kryo kryo, Output output, TargetMean t) { + kryo.writeObjectOrNull(output, new TargetMeanData(t), TargetMeanData.class); + } + + @Override + public TargetMean read(Kryo kryo, Input input, Class<TargetMean> type) { + TargetMeanData data = kryo.readObjectOrNull(input, TargetMeanData.class); + return data.build(); + } + } +} http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/787864b6/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/distributed/AMRDefaultRuleProcessor.java ---------------------------------------------------------------------- diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/distributed/AMRDefaultRuleProcessor.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/distributed/AMRDefaultRuleProcessor.java new file mode 100644 index 0000000..54a4006 --- /dev/null +++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/distributed/AMRDefaultRuleProcessor.java @@ -0,0 +1,334 @@ +package com.yahoo.labs.samoa.learners.classifiers.rules.distributed; + +/* + * #%L + * SAMOA + * %% + * Copyright (C) 2013 - 2014 Yahoo! Inc. + * %% + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * #L% + */ + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import com.yahoo.labs.samoa.core.ContentEvent; +import com.yahoo.labs.samoa.core.Processor; +import com.yahoo.labs.samoa.instances.Instance; +import com.yahoo.labs.samoa.instances.Instances; +import com.yahoo.labs.samoa.learners.InstanceContentEvent; +import com.yahoo.labs.samoa.learners.ResultContentEvent; +import com.yahoo.labs.samoa.learners.classifiers.rules.common.ActiveRule; +import com.yahoo.labs.samoa.learners.classifiers.rules.common.Perceptron; +import com.yahoo.labs.samoa.learners.classifiers.rules.common.RuleActiveRegressionNode; +import com.yahoo.labs.samoa.moa.classifiers.rules.core.attributeclassobservers.FIMTDDNumericAttributeClassLimitObserver; +import com.yahoo.labs.samoa.topology.Stream; + +/** + * Default Rule Learner Processor (HAMR). + * + * @author Anh Thu Vu + * + */ +public class AMRDefaultRuleProcessor implements Processor { + + /** + * + */ + private static final long serialVersionUID = 23702084591044447L; + + private static final Logger logger = + LoggerFactory.getLogger(AMRDefaultRuleProcessor.class); + + private int processorId; + + // Default rule + protected transient ActiveRule defaultRule; + protected transient int ruleNumberID; + protected transient double[] statistics; + + // SAMOA Stream + private Stream ruleStream; + private Stream resultStream; + + // Options + protected int pageHinckleyThreshold; + protected double pageHinckleyAlpha; + protected boolean driftDetection; + protected int predictionFunction; // Adaptive=0 Perceptron=1 TargetMean=2 + protected boolean constantLearningRatioDecay; + protected double learningRatio; + + protected double splitConfidence; + protected double tieThreshold; + protected int gracePeriod; + + protected FIMTDDNumericAttributeClassLimitObserver numericObserver; + + /* + * Constructor + */ + public AMRDefaultRuleProcessor (Builder builder) { + this.pageHinckleyThreshold = builder.pageHinckleyThreshold; + 
this.pageHinckleyAlpha = builder.pageHinckleyAlpha; + this.driftDetection = builder.driftDetection; + this.predictionFunction = builder.predictionFunction; + this.constantLearningRatioDecay = builder.constantLearningRatioDecay; + this.learningRatio = builder.learningRatio; + this.splitConfidence = builder.splitConfidence; + this.tieThreshold = builder.tieThreshold; + this.gracePeriod = builder.gracePeriod; + + this.numericObserver = builder.numericObserver; + } + + @Override + public boolean process(ContentEvent event) { + InstanceContentEvent instanceEvent = (InstanceContentEvent) event; + // predict + if (instanceEvent.isTesting()) { + this.predictOnInstance(instanceEvent); + } + + // train + if (instanceEvent.isTraining()) { + this.trainOnInstance(instanceEvent); + } + + return false; + } + + /* + * Prediction + */ + private void predictOnInstance (InstanceContentEvent instanceEvent) { + double [] vote=defaultRule.getPrediction(instanceEvent.getInstance()); + ResultContentEvent rce = newResultContentEvent(vote, instanceEvent); + resultStream.put(rce); + } + + private ResultContentEvent newResultContentEvent(double[] prediction, InstanceContentEvent inEvent){ + ResultContentEvent rce = new ResultContentEvent(inEvent.getInstanceIndex(), inEvent.getInstance(), inEvent.getClassId(), prediction, inEvent.isLastEvent()); + rce.setClassifierIndex(this.processorId); + rce.setEvaluationIndex(inEvent.getEvaluationIndex()); + return rce; + } + + /* + * Training + */ + private void trainOnInstance (InstanceContentEvent instanceEvent) { + this.trainOnInstanceImpl(instanceEvent.getInstance()); + } + public void trainOnInstanceImpl(Instance instance) { + defaultRule.updateStatistics(instance); + if (defaultRule.getInstancesSeen() % this.gracePeriod == 0.0) { + if (defaultRule.tryToExpand(this.splitConfidence, this.tieThreshold) == true) { + ActiveRule newDefaultRule=newRule(defaultRule.getRuleNumberID(),(RuleActiveRegressionNode)defaultRule.getLearningNode(), + 
((RuleActiveRegressionNode)defaultRule.getLearningNode()).getStatisticsOtherBranchSplit()); //other branch + defaultRule.split(); + defaultRule.setRuleNumberID(++ruleNumberID); + // send out the new rule + sendAddRuleEvent(defaultRule.getRuleNumberID(), this.defaultRule); + defaultRule=newDefaultRule; + } + } + } + + /* + * Create new rules + */ + private ActiveRule newRule(int ID, RuleActiveRegressionNode node, double[] statistics) { + ActiveRule r=newRule(ID); + + if (node!=null) + { + if(node.getPerceptron()!=null) + { + r.getLearningNode().setPerceptron(new Perceptron(node.getPerceptron())); + r.getLearningNode().getPerceptron().setLearningRatio(this.learningRatio); + } + if (statistics==null) + { + double mean; + if(node.getNodeStatistics().getValue(0)>0){ + mean=node.getNodeStatistics().getValue(1)/node.getNodeStatistics().getValue(0); + r.getLearningNode().getTargetMean().reset(mean, 1); + } + } + } + if (statistics!=null && ((RuleActiveRegressionNode)r.getLearningNode()).getTargetMean()!=null) + { + double mean; + if(statistics[0]>0){ + mean=statistics[1]/statistics[0]; + ((RuleActiveRegressionNode)r.getLearningNode()).getTargetMean().reset(mean, (long)statistics[0]); + } + } + return r; + } + + private ActiveRule newRule(int ID) { + ActiveRule r=new ActiveRule.Builder(). + threshold(this.pageHinckleyThreshold). + alpha(this.pageHinckleyAlpha). + changeDetection(this.driftDetection). + predictionFunction(this.predictionFunction). + statistics(new double[3]). + learningRatio(this.learningRatio). + numericObserver(numericObserver). 
+ id(ID).build(); + return r; + } + + @Override + public void onCreate(int id) { + this.processorId = id; + this.statistics= new double[]{0.0,0,0}; + this.ruleNumberID=0; + this.defaultRule = newRule(++this.ruleNumberID); + } + + /* + * Clone processor + */ + @Override + public Processor newProcessor(Processor p) { + AMRDefaultRuleProcessor oldProcessor = (AMRDefaultRuleProcessor) p; + Builder builder = new Builder(oldProcessor); + AMRDefaultRuleProcessor newProcessor = builder.build(); + newProcessor.resultStream = oldProcessor.resultStream; + newProcessor.ruleStream = oldProcessor.ruleStream; + return newProcessor; + } + + /* + * Send events + */ + private void sendAddRuleEvent(int ruleID, ActiveRule rule) { + RuleContentEvent rce = new RuleContentEvent(ruleID, rule, false); + this.ruleStream.put(rce); + } + + /* + * Output streams + */ + public void setRuleStream(Stream ruleStream) { + this.ruleStream = ruleStream; + } + + public Stream getRuleStream() { + return this.ruleStream; + } + + public void setResultStream(Stream resultStream) { + this.resultStream = resultStream; + } + + public Stream getResultStream() { + return this.resultStream; + } + + /* + * Builder + */ + public static class Builder { + private int pageHinckleyThreshold; + private double pageHinckleyAlpha; + private boolean driftDetection; + private int predictionFunction; // Adaptive=0 Perceptron=1 TargetMean=2 + private boolean constantLearningRatioDecay; + private double learningRatio; + private double splitConfidence; + private double tieThreshold; + private int gracePeriod; + + private FIMTDDNumericAttributeClassLimitObserver numericObserver; + + private Instances dataset; + + public Builder(Instances dataset){ + this.dataset = dataset; + } + + public Builder(AMRDefaultRuleProcessor processor) { + this.pageHinckleyThreshold = processor.pageHinckleyThreshold; + this.pageHinckleyAlpha = processor.pageHinckleyAlpha; + this.driftDetection = processor.driftDetection; + this.predictionFunction = 
processor.predictionFunction; + this.constantLearningRatioDecay = processor.constantLearningRatioDecay; + this.learningRatio = processor.learningRatio; + this.splitConfidence = processor.splitConfidence; + this.tieThreshold = processor.tieThreshold; + this.gracePeriod = processor.gracePeriod; + + this.numericObserver = processor.numericObserver; + } + + public Builder threshold(int threshold) { + this.pageHinckleyThreshold = threshold; + return this; + } + + public Builder alpha(double alpha) { + this.pageHinckleyAlpha = alpha; + return this; + } + + public Builder changeDetection(boolean changeDetection) { + this.driftDetection = changeDetection; + return this; + } + + public Builder predictionFunction(int predictionFunction) { + this.predictionFunction = predictionFunction; + return this; + } + + public Builder constantLearningRatioDecay(boolean constantDecay) { + this.constantLearningRatioDecay = constantDecay; + return this; + } + + public Builder learningRatio(double learningRatio) { + this.learningRatio = learningRatio; + return this; + } + + public Builder splitConfidence(double splitConfidence) { + this.splitConfidence = splitConfidence; + return this; + } + + public Builder tieThreshold(double tieThreshold) { + this.tieThreshold = tieThreshold; + return this; + } + + public Builder gracePeriod(int gracePeriod) { + this.gracePeriod = gracePeriod; + return this; + } + + public Builder numericObserver(FIMTDDNumericAttributeClassLimitObserver numericObserver) { + this.numericObserver = numericObserver; + return this; + } + + public AMRDefaultRuleProcessor build() { + return new AMRDefaultRuleProcessor(this); + } + } + +} http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/787864b6/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/distributed/AMRLearnerProcessor.java ---------------------------------------------------------------------- diff --git 
a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/distributed/AMRLearnerProcessor.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/distributed/AMRLearnerProcessor.java new file mode 100644 index 0000000..8ec118d --- /dev/null +++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/distributed/AMRLearnerProcessor.java @@ -0,0 +1,259 @@ +package com.yahoo.labs.samoa.learners.classifiers.rules.distributed; + +/* + * #%L + * SAMOA + * %% + * Copyright (C) 2013 - 2014 Yahoo! Inc. + * %% + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * #L% + */ + +import java.util.Iterator; +import java.util.LinkedList; +import java.util.List; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import com.yahoo.labs.samoa.core.ContentEvent; +import com.yahoo.labs.samoa.core.Processor; +import com.yahoo.labs.samoa.instances.Instance; +import com.yahoo.labs.samoa.instances.Instances; +import com.yahoo.labs.samoa.learners.classifiers.rules.common.ActiveRule; +import com.yahoo.labs.samoa.learners.classifiers.rules.common.LearningRule; +import com.yahoo.labs.samoa.learners.classifiers.rules.common.RuleActiveRegressionNode; +import com.yahoo.labs.samoa.learners.classifiers.rules.common.RulePassiveRegressionNode; +import com.yahoo.labs.samoa.learners.classifiers.rules.common.RuleSplitNode; +import com.yahoo.labs.samoa.topology.Stream; + +/** + * Learner Processor (HAMR). 
+ * + * @author Anh Thu Vu + * + */ +public class AMRLearnerProcessor implements Processor { + + /** + * + */ + private static final long serialVersionUID = -2302897295090248013L; + + private static final Logger logger = + LoggerFactory.getLogger(AMRLearnerProcessor.class); + + private int processorId; + + private transient List<ActiveRule> ruleSet; + + private Stream outputStream; + + private double splitConfidence; + private double tieThreshold; + private int gracePeriod; + + private boolean noAnomalyDetection; + private double multivariateAnomalyProbabilityThreshold; + private double univariateAnomalyprobabilityThreshold; + private int anomalyNumInstThreshold; + + public AMRLearnerProcessor(Builder builder) { + this.splitConfidence = builder.splitConfidence; + this.tieThreshold = builder.tieThreshold; + this.gracePeriod = builder.gracePeriod; + + this.noAnomalyDetection = builder.noAnomalyDetection; + this.multivariateAnomalyProbabilityThreshold = builder.multivariateAnomalyProbabilityThreshold; + this.univariateAnomalyprobabilityThreshold = builder.univariateAnomalyprobabilityThreshold; + this.anomalyNumInstThreshold = builder.anomalyNumInstThreshold; + } + + @Override + public boolean process(ContentEvent event) { + if (event instanceof AssignmentContentEvent) { + AssignmentContentEvent attrContentEvent = (AssignmentContentEvent) event; + trainRuleOnInstance(attrContentEvent.getRuleNumberID(),attrContentEvent.getInstance()); + } + else if (event instanceof RuleContentEvent) { + RuleContentEvent ruleContentEvent = (RuleContentEvent) event; + if (!ruleContentEvent.isRemoving()) { + addRule(ruleContentEvent.getRule()); + } + } + + return false; + } + + /* + * Process input instances + */ + private void trainRuleOnInstance(int ruleID, Instance instance) { + //System.out.println("Processor:"+this.processorId+": Rule:"+ruleID+" -> Counter="+counter); + Iterator<ActiveRule> ruleIterator= this.ruleSet.iterator(); + while (ruleIterator.hasNext()) { + ActiveRule rule = 
ruleIterator.next(); + if (rule.getRuleNumberID() == ruleID) { + // Check (again) for coverage + if (rule.isCovering(instance) == true) { + double error = rule.computeError(instance); //Use adaptive mode error + boolean changeDetected = ((RuleActiveRegressionNode)rule.getLearningNode()).updateChangeDetection(error); + if (changeDetected == true) { + ruleIterator.remove(); + + this.sendRemoveRuleEvent(ruleID); + } else { + rule.updateStatistics(instance); + if (rule.getInstancesSeen() % this.gracePeriod == 0.0) { + if (rule.tryToExpand(this.splitConfidence, this.tieThreshold) ) { + rule.split(); + + // expanded: update Aggregator with new/updated predicate + this.sendPredicate(rule.getRuleNumberID(), rule.getLastUpdatedRuleSplitNode(), + (RuleActiveRegressionNode)rule.getLearningNode()); + } + + } + + } + } + + return; + } + } + } + + private boolean isAnomaly(Instance instance, LearningRule rule) { + //AMRUles is equipped with anomaly detection. If on, compute the anomaly value. + boolean isAnomaly = false; + if (this.noAnomalyDetection == false){ + if (rule.getInstancesSeen() >= this.anomalyNumInstThreshold) { + isAnomaly = rule.isAnomaly(instance, + this.univariateAnomalyprobabilityThreshold, + this.multivariateAnomalyProbabilityThreshold, + this.anomalyNumInstThreshold); + } + } + return isAnomaly; + } + + private void sendRemoveRuleEvent(int ruleID) { + RuleContentEvent rce = new RuleContentEvent(ruleID, null, true); + this.outputStream.put(rce); + } + + private void sendPredicate(int ruleID, RuleSplitNode splitNode, RuleActiveRegressionNode learningNode) { + this.outputStream.put(new PredicateContentEvent(ruleID, splitNode, new RulePassiveRegressionNode(learningNode))); + } + + /* + * Process control message (regarding adding or removing rules) + */ + private boolean addRule(ActiveRule rule) { + this.ruleSet.add(rule); + return true; + } + + @Override + public void onCreate(int id) { + this.processorId = id; + this.ruleSet = new LinkedList<ActiveRule>(); + } + 
+ @Override + public Processor newProcessor(Processor p) { + AMRLearnerProcessor oldProcessor = (AMRLearnerProcessor)p; + AMRLearnerProcessor newProcessor = + new AMRLearnerProcessor.Builder(oldProcessor).build(); + + newProcessor.setOutputStream(oldProcessor.outputStream); + return newProcessor; + } + + /* + * Builder + */ + public static class Builder { + private double splitConfidence; + private double tieThreshold; + private int gracePeriod; + + private boolean noAnomalyDetection; + private double multivariateAnomalyProbabilityThreshold; + private double univariateAnomalyprobabilityThreshold; + private int anomalyNumInstThreshold; + + private Instances dataset; + + public Builder(Instances dataset){ + this.dataset = dataset; + } + + public Builder(AMRLearnerProcessor processor) { + this.splitConfidence = processor.splitConfidence; + this.tieThreshold = processor.tieThreshold; + this.gracePeriod = processor.gracePeriod; + } + + public Builder splitConfidence(double splitConfidence) { + this.splitConfidence = splitConfidence; + return this; + } + + public Builder tieThreshold(double tieThreshold) { + this.tieThreshold = tieThreshold; + return this; + } + + public Builder gracePeriod(int gracePeriod) { + this.gracePeriod = gracePeriod; + return this; + } + + public Builder noAnomalyDetection(boolean noAnomalyDetection) { + this.noAnomalyDetection = noAnomalyDetection; + return this; + } + + public Builder multivariateAnomalyProbabilityThreshold(double mAnomalyThreshold) { + this.multivariateAnomalyProbabilityThreshold = mAnomalyThreshold; + return this; + } + + public Builder univariateAnomalyProbabilityThreshold(double uAnomalyThreshold) { + this.univariateAnomalyprobabilityThreshold = uAnomalyThreshold; + return this; + } + + public Builder anomalyNumberOfInstancesThreshold(int anomalyNumInstThreshold) { + this.anomalyNumInstThreshold = anomalyNumInstThreshold; + return this; + } + + public AMRLearnerProcessor build() { + return new AMRLearnerProcessor(this); + 
} + } + + /* + * Output stream + */ + public void setOutputStream(Stream stream) { + this.outputStream = stream; + } + + public Stream getOutputStream() { + return this.outputStream; + } +} http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/787864b6/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/distributed/AMRRuleSetProcessor.java ---------------------------------------------------------------------- diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/distributed/AMRRuleSetProcessor.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/distributed/AMRRuleSetProcessor.java new file mode 100644 index 0000000..38a0be1 --- /dev/null +++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/distributed/AMRRuleSetProcessor.java @@ -0,0 +1,362 @@ +package com.yahoo.labs.samoa.learners.classifiers.rules.distributed; + +/* + * #%L + * SAMOA + * %% + * Copyright (C) 2013 - 2014 Yahoo! Inc. + * %% + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * #L% + */ + +import com.yahoo.labs.samoa.core.ContentEvent; +import com.yahoo.labs.samoa.core.Processor; +import com.yahoo.labs.samoa.instances.Instance; +import com.yahoo.labs.samoa.instances.Instances; +import com.yahoo.labs.samoa.learners.InstanceContentEvent; +import com.yahoo.labs.samoa.learners.ResultContentEvent; +import com.yahoo.labs.samoa.learners.classifiers.rules.common.ActiveRule; +import com.yahoo.labs.samoa.learners.classifiers.rules.common.LearningRule; +import com.yahoo.labs.samoa.learners.classifiers.rules.common.PassiveRule; +import com.yahoo.labs.samoa.moa.classifiers.rules.core.voting.ErrorWeightedVote; +import com.yahoo.labs.samoa.moa.classifiers.rules.core.voting.InverseErrorWeightedVote; +import com.yahoo.labs.samoa.moa.classifiers.rules.core.voting.UniformWeightedVote; +import com.yahoo.labs.samoa.topology.Stream; +import java.util.LinkedList; +import java.util.List; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Model Aggregator Processor (HAMR). 
+ * @author Anh Thu Vu + * + */ +public class AMRRuleSetProcessor implements Processor { + + /** + * + */ + private static final long serialVersionUID = -6544096255649379334L; + private static final Logger logger = LoggerFactory.getLogger(AMRRuleSetProcessor.class); + + private int processorId; + + // Rules & default rule + protected transient List<PassiveRule> ruleSet; + + // SAMOA Stream + private Stream statisticsStream; + private Stream resultStream; + private Stream defaultRuleStream; + + // Options + protected boolean noAnomalyDetection; + protected double multivariateAnomalyProbabilityThreshold; + protected double univariateAnomalyprobabilityThreshold; + protected int anomalyNumInstThreshold; + + protected boolean unorderedRules; + + protected int voteType; + + /* + * Constructor + */ + public AMRRuleSetProcessor (Builder builder) { + + this.noAnomalyDetection = builder.noAnomalyDetection; + this.multivariateAnomalyProbabilityThreshold = builder.multivariateAnomalyProbabilityThreshold; + this.univariateAnomalyprobabilityThreshold = builder.univariateAnomalyprobabilityThreshold; + this.anomalyNumInstThreshold = builder.anomalyNumInstThreshold; + this.unorderedRules = builder.unorderedRules; + + this.voteType = builder.voteType; + } + /* (non-Javadoc) + * @see com.yahoo.labs.samoa.core.Processor#process(com.yahoo.labs.samoa.core.ContentEvent) + */ + @Override + public boolean process(ContentEvent event) { + if (event instanceof InstanceContentEvent) { + this.processInstanceEvent((InstanceContentEvent) event); + } + else if (event instanceof PredicateContentEvent) { + PredicateContentEvent pce = (PredicateContentEvent) event; + if (pce.getRuleSplitNode() == null) { + this.updateLearningNode(pce); + } + else { + this.updateRuleSplitNode(pce); + } + } + else if (event instanceof RuleContentEvent) { + RuleContentEvent rce = (RuleContentEvent) event; + if (rce.isRemoving()) { + this.removeRule(rce.getRuleNumberID()); + } + else { + addRule(rce.getRule()); + } + } + 
return true; + } + + private void processInstanceEvent(InstanceContentEvent instanceEvent) { + Instance instance = instanceEvent.getInstance(); + boolean predictionCovered = false; + boolean trainingCovered = false; + boolean continuePrediction = instanceEvent.isTesting(); + boolean continueTraining = instanceEvent.isTraining(); + + ErrorWeightedVote errorWeightedVote = newErrorWeightedVote(); + for (PassiveRule aRuleSet : this.ruleSet) { + if (!continuePrediction && !continueTraining) + break; + + if (aRuleSet.isCovering(instance)) { + predictionCovered = true; + + if (continuePrediction) { + double[] vote = aRuleSet.getPrediction(instance); + double error = aRuleSet.getCurrentError(); + errorWeightedVote.addVote(vote, error); + if (!this.unorderedRules) continuePrediction = false; + } + + if (continueTraining) { + if (!isAnomaly(instance, aRuleSet)) { + trainingCovered = true; + aRuleSet.updateStatistics(instance); + + // Send instance to statistics PIs + sendInstanceToRule(instance, aRuleSet.getRuleNumberID()); + + if (!this.unorderedRules) continueTraining = false; + } + } + } + } + + if (predictionCovered) { + // Combined prediction + ResultContentEvent rce = newResultContentEvent(errorWeightedVote.computeWeightedVote(), instanceEvent); + resultStream.put(rce); + } + + boolean defaultPrediction = instanceEvent.isTesting() && !predictionCovered; + boolean defaultTraining = instanceEvent.isTraining() && !trainingCovered; + if (defaultPrediction || defaultTraining) { + instanceEvent.setTesting(defaultPrediction); + instanceEvent.setTraining(defaultTraining); + this.defaultRuleStream.put(instanceEvent); + } + } + + private ResultContentEvent newResultContentEvent(double[] prediction, InstanceContentEvent inEvent){ + ResultContentEvent rce = new ResultContentEvent(inEvent.getInstanceIndex(), inEvent.getInstance(), inEvent.getClassId(), prediction, inEvent.isLastEvent()); + rce.setClassifierIndex(this.processorId); + 
rce.setEvaluationIndex(inEvent.getEvaluationIndex()); + return rce; + } + + public ErrorWeightedVote newErrorWeightedVote() { + // TODO: do a reset instead of init a new object + if (voteType == 1) + return new UniformWeightedVote(); + return new InverseErrorWeightedVote(); + } + + /** + * Method to verify if the instance is an anomaly. + * @param instance + * @param rule + * @return + */ + private boolean isAnomaly(Instance instance, LearningRule rule) { + //AMRUles is equipped with anomaly detection. If on, compute the anomaly value. + boolean isAnomaly = false; + if (!this.noAnomalyDetection){ + if (rule.getInstancesSeen() >= this.anomalyNumInstThreshold) { + isAnomaly = rule.isAnomaly(instance, + this.univariateAnomalyprobabilityThreshold, + this.multivariateAnomalyProbabilityThreshold, + this.anomalyNumInstThreshold); + } + } + return isAnomaly; + } + + /* + * Add predicate/RuleSplitNode for a rule + */ + private void updateRuleSplitNode(PredicateContentEvent pce) { + int ruleID = pce.getRuleNumberID(); + for (PassiveRule rule:ruleSet) { + if (rule.getRuleNumberID() == ruleID) { + rule.nodeListAdd(pce.getRuleSplitNode()); + rule.setLearningNode(pce.getLearningNode()); + } + } + } + + private void updateLearningNode(PredicateContentEvent pce) { + int ruleID = pce.getRuleNumberID(); + for (PassiveRule rule:ruleSet) { + if (rule.getRuleNumberID() == ruleID) { + rule.setLearningNode(pce.getLearningNode()); + } + } + } + + /* + * Add new rule/Remove rule + */ + private boolean addRule(ActiveRule rule) { + this.ruleSet.add(new PassiveRule(rule)); + return true; + } + + private void removeRule(int ruleID) { + for (PassiveRule rule:ruleSet) { + if (rule.getRuleNumberID() == ruleID) { + ruleSet.remove(rule); + break; + } + } + } + + @Override + public void onCreate(int id) { + this.processorId = id; + this.ruleSet = new LinkedList<PassiveRule>(); + + } + + /* + * Clone processor + */ + @Override + public Processor newProcessor(Processor p) { + AMRRuleSetProcessor 
oldProcessor = (AMRRuleSetProcessor) p; + Builder builder = new Builder(oldProcessor); + AMRRuleSetProcessor newProcessor = builder.build(); + newProcessor.resultStream = oldProcessor.resultStream; + newProcessor.statisticsStream = oldProcessor.statisticsStream; + newProcessor.defaultRuleStream = oldProcessor.defaultRuleStream; + return newProcessor; + } + + /* + * Send events + */ + private void sendInstanceToRule(Instance instance, int ruleID) { + AssignmentContentEvent ace = new AssignmentContentEvent(ruleID, instance); + this.statisticsStream.put(ace); + } + + /* + * Output streams + */ + public void setStatisticsStream(Stream statisticsStream) { + this.statisticsStream = statisticsStream; + } + + public Stream getStatisticsStream() { + return this.statisticsStream; + } + + public void setResultStream(Stream resultStream) { + this.resultStream = resultStream; + } + + public Stream getResultStream() { + return this.resultStream; + } + + public Stream getDefaultRuleStream() { + return this.defaultRuleStream; + } + + public void setDefaultRuleStream(Stream defaultRuleStream) { + this.defaultRuleStream = defaultRuleStream; + } + + /* + * Builder + */ + public static class Builder { + private boolean noAnomalyDetection; + private double multivariateAnomalyProbabilityThreshold; + private double univariateAnomalyprobabilityThreshold; + private int anomalyNumInstThreshold; + + private boolean unorderedRules; + +// private FIMTDDNumericAttributeClassLimitObserver numericObserver; + private int voteType; + + private Instances dataset; + + public Builder(Instances dataset){ + this.dataset = dataset; + } + + public Builder(AMRRuleSetProcessor processor) { + + this.noAnomalyDetection = processor.noAnomalyDetection; + this.multivariateAnomalyProbabilityThreshold = processor.multivariateAnomalyProbabilityThreshold; + this.univariateAnomalyprobabilityThreshold = processor.univariateAnomalyprobabilityThreshold; + this.anomalyNumInstThreshold = processor.anomalyNumInstThreshold; 
+ this.unorderedRules = processor.unorderedRules; + + this.voteType = processor.voteType; + } + + public Builder noAnomalyDetection(boolean noAnomalyDetection) { + this.noAnomalyDetection = noAnomalyDetection; + return this; + } + + public Builder multivariateAnomalyProbabilityThreshold(double mAnomalyThreshold) { + this.multivariateAnomalyProbabilityThreshold = mAnomalyThreshold; + return this; + } + + public Builder univariateAnomalyProbabilityThreshold(double uAnomalyThreshold) { + this.univariateAnomalyprobabilityThreshold = uAnomalyThreshold; + return this; + } + + public Builder anomalyNumberOfInstancesThreshold(int anomalyNumInstThreshold) { + this.anomalyNumInstThreshold = anomalyNumInstThreshold; + return this; + } + + public Builder unorderedRules(boolean unorderedRules) { + this.unorderedRules = unorderedRules; + return this; + } + + public Builder voteType(int voteType) { + this.voteType = voteType; + return this; + } + + public AMRRuleSetProcessor build() { + return new AMRRuleSetProcessor(this); + } + } + +}
