http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/787864b6/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/distributed/AMRulesAggregatorProcessor.java ---------------------------------------------------------------------- diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/distributed/AMRulesAggregatorProcessor.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/distributed/AMRulesAggregatorProcessor.java new file mode 100644 index 0000000..debe912 --- /dev/null +++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/distributed/AMRulesAggregatorProcessor.java @@ -0,0 +1,525 @@ +package com.yahoo.labs.samoa.learners.classifiers.rules.distributed; + +/* + * #%L + * SAMOA + * %% + * Copyright (C) 2013 - 2014 Yahoo! Inc. + * %% + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * #L% + */ + +import java.util.Iterator; +import java.util.LinkedList; +import java.util.List; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import com.yahoo.labs.samoa.core.ContentEvent; +import com.yahoo.labs.samoa.core.Processor; +import com.yahoo.labs.samoa.instances.Instance; +import com.yahoo.labs.samoa.instances.Instances; +import com.yahoo.labs.samoa.learners.InstanceContentEvent; +import com.yahoo.labs.samoa.learners.ResultContentEvent; +import com.yahoo.labs.samoa.learners.classifiers.rules.common.ActiveRule; +import com.yahoo.labs.samoa.learners.classifiers.rules.common.LearningRule; +import com.yahoo.labs.samoa.learners.classifiers.rules.common.PassiveRule; +import com.yahoo.labs.samoa.learners.classifiers.rules.common.Perceptron; +import com.yahoo.labs.samoa.learners.classifiers.rules.common.RuleActiveRegressionNode; +import com.yahoo.labs.samoa.moa.classifiers.rules.core.attributeclassobservers.FIMTDDNumericAttributeClassLimitObserver; +import com.yahoo.labs.samoa.moa.classifiers.rules.core.voting.ErrorWeightedVote; +import com.yahoo.labs.samoa.moa.classifiers.rules.core.voting.InverseErrorWeightedVote; +import com.yahoo.labs.samoa.moa.classifiers.rules.core.voting.UniformWeightedVote; +import com.yahoo.labs.samoa.topology.Stream; + +/** + * Model Aggregator Processor (VAMR). 
+ * + * @author Anh Thu Vu + * + */ +public class AMRulesAggregatorProcessor implements Processor { + + /** + * + */ + private static final long serialVersionUID = 6303385725332704251L; + + private static final Logger logger = + LoggerFactory.getLogger(AMRulesAggregatorProcessor.class); + + private int processorId; + + // Rules & default rule + protected transient List<PassiveRule> ruleSet; + protected transient ActiveRule defaultRule; + protected transient int ruleNumberID; + protected transient double[] statistics; + + // SAMOA Stream + private Stream statisticsStream; + private Stream resultStream; + + // Options + protected int pageHinckleyThreshold; + protected double pageHinckleyAlpha; + protected boolean driftDetection; + protected int predictionFunction; // Adaptive=0 Perceptron=1 TargetMean=2 + protected boolean constantLearningRatioDecay; + protected double learningRatio; + + protected double splitConfidence; + protected double tieThreshold; + protected int gracePeriod; + + protected boolean noAnomalyDetection; + protected double multivariateAnomalyProbabilityThreshold; + protected double univariateAnomalyprobabilityThreshold; + protected int anomalyNumInstThreshold; + + protected boolean unorderedRules; + + protected FIMTDDNumericAttributeClassLimitObserver numericObserver; + protected int voteType; + + /* + * Constructor + */ + public AMRulesAggregatorProcessor (Builder builder) { + this.pageHinckleyThreshold = builder.pageHinckleyThreshold; + this.pageHinckleyAlpha = builder.pageHinckleyAlpha; + this.driftDetection = builder.driftDetection; + this.predictionFunction = builder.predictionFunction; + this.constantLearningRatioDecay = builder.constantLearningRatioDecay; + this.learningRatio = builder.learningRatio; + this.splitConfidence = builder.splitConfidence; + this.tieThreshold = builder.tieThreshold; + this.gracePeriod = builder.gracePeriod; + + this.noAnomalyDetection = builder.noAnomalyDetection; + this.multivariateAnomalyProbabilityThreshold = 
builder.multivariateAnomalyProbabilityThreshold; + this.univariateAnomalyprobabilityThreshold = builder.univariateAnomalyprobabilityThreshold; + this.anomalyNumInstThreshold = builder.anomalyNumInstThreshold; + this.unorderedRules = builder.unorderedRules; + + this.numericObserver = builder.numericObserver; + this.voteType = builder.voteType; + } + + /* + * Process + */ + @Override + public boolean process(ContentEvent event) { + if (event instanceof InstanceContentEvent) { + InstanceContentEvent instanceEvent = (InstanceContentEvent) event; + this.processInstanceEvent(instanceEvent); + } + else if (event instanceof PredicateContentEvent) { + this.updateRuleSplitNode((PredicateContentEvent) event); + } + else if (event instanceof RuleContentEvent) { + RuleContentEvent rce = (RuleContentEvent) event; + if (rce.isRemoving()) { + this.removeRule(rce.getRuleNumberID()); + } + } + + return true; + } + + // Merge predict and train so we only check for covering rules one time + private void processInstanceEvent(InstanceContentEvent instanceEvent) { + Instance instance = instanceEvent.getInstance(); + boolean predictionCovered = false; + boolean trainingCovered = false; + boolean continuePrediction = instanceEvent.isTesting(); + boolean continueTraining = instanceEvent.isTraining(); + + ErrorWeightedVote errorWeightedVote = newErrorWeightedVote(); + Iterator<PassiveRule> ruleIterator= this.ruleSet.iterator(); + while (ruleIterator.hasNext()) { + if (!continuePrediction && !continueTraining) + break; + + PassiveRule rule = ruleIterator.next(); + + if (rule.isCovering(instance) == true){ + predictionCovered = true; + + if (continuePrediction) { + double [] vote=rule.getPrediction(instance); + double error= rule.getCurrentError(); + errorWeightedVote.addVote(vote,error); + if (!this.unorderedRules) continuePrediction = false; + } + + if (continueTraining) { + if (!isAnomaly(instance, rule)) { + trainingCovered = true; + rule.updateStatistics(instance); + // Send instance to 
statistics PIs + sendInstanceToRule(instance, rule.getRuleNumberID()); + + if (!this.unorderedRules) continueTraining = false; + } + } + } + } + + if (predictionCovered) { + // Combined prediction + ResultContentEvent rce = newResultContentEvent(errorWeightedVote.computeWeightedVote(), instanceEvent); + resultStream.put(rce); + } + else if (instanceEvent.isTesting()) { + // predict with default rule + double [] vote=defaultRule.getPrediction(instance); + ResultContentEvent rce = newResultContentEvent(vote, instanceEvent); + resultStream.put(rce); + } + + if (!trainingCovered && instanceEvent.isTraining()) { + // train default rule with this instance + defaultRule.updateStatistics(instance); + if (defaultRule.getInstancesSeen() % this.gracePeriod == 0.0) { + if (defaultRule.tryToExpand(this.splitConfidence, this.tieThreshold) == true) { + ActiveRule newDefaultRule=newRule(defaultRule.getRuleNumberID(),(RuleActiveRegressionNode)defaultRule.getLearningNode(), + ((RuleActiveRegressionNode)defaultRule.getLearningNode()).getStatisticsOtherBranchSplit()); //other branch + defaultRule.split(); + defaultRule.setRuleNumberID(++ruleNumberID); + this.ruleSet.add(new PassiveRule(this.defaultRule)); + // send to statistics PI + sendAddRuleEvent(defaultRule.getRuleNumberID(), this.defaultRule); + defaultRule=newDefaultRule; + } + } + } + } + + /** + * Helper method to generate new ResultContentEvent based on an instance and + * its prediction result. + * @param prediction The predicted class label from the decision tree model. + * @param inEvent The associated instance content event + * @return ResultContentEvent to be sent into Evaluator PI or other destination PI. 
+ */ + private ResultContentEvent newResultContentEvent(double[] prediction, InstanceContentEvent inEvent){ + ResultContentEvent rce = new ResultContentEvent(inEvent.getInstanceIndex(), inEvent.getInstance(), inEvent.getClassId(), prediction, inEvent.isLastEvent()); + rce.setClassifierIndex(this.processorId); + rce.setEvaluationIndex(inEvent.getEvaluationIndex()); + return rce; + } + + public ErrorWeightedVote newErrorWeightedVote() { + if (voteType == 1) + return new UniformWeightedVote(); + return new InverseErrorWeightedVote(); + } + + /** + * Method to verify if the instance is an anomaly. + * @param instance + * @param rule + * @return + */ + private boolean isAnomaly(Instance instance, LearningRule rule) { + //AMRUles is equipped with anomaly detection. If on, compute the anomaly value. + boolean isAnomaly = false; + if (this.noAnomalyDetection == false){ + if (rule.getInstancesSeen() >= this.anomalyNumInstThreshold) { + isAnomaly = rule.isAnomaly(instance, + this.univariateAnomalyprobabilityThreshold, + this.multivariateAnomalyProbabilityThreshold, + this.anomalyNumInstThreshold); + } + } + return isAnomaly; + } + + /* + * Create new rules + */ + private ActiveRule newRule(int ID, RuleActiveRegressionNode node, double[] statistics) { + ActiveRule r=newRule(ID); + + if (node!=null) + { + if(node.getPerceptron()!=null) + { + r.getLearningNode().setPerceptron(new Perceptron(node.getPerceptron())); + r.getLearningNode().getPerceptron().setLearningRatio(this.learningRatio); + } + if (statistics==null) + { + double mean; + if(node.getNodeStatistics().getValue(0)>0){ + mean=node.getNodeStatistics().getValue(1)/node.getNodeStatistics().getValue(0); + r.getLearningNode().getTargetMean().reset(mean, 1); + } + } + } + if (statistics!=null && ((RuleActiveRegressionNode)r.getLearningNode()).getTargetMean()!=null) + { + double mean; + if(statistics[0]>0){ + mean=statistics[1]/statistics[0]; + ((RuleActiveRegressionNode)r.getLearningNode()).getTargetMean().reset(mean, 
(long)statistics[0]); + } + } + return r; + } + + private ActiveRule newRule(int ID) { + ActiveRule r=new ActiveRule.Builder(). + threshold(this.pageHinckleyThreshold). + alpha(this.pageHinckleyAlpha). + changeDetection(this.driftDetection). + predictionFunction(this.predictionFunction). + statistics(new double[3]). + learningRatio(this.learningRatio). + numericObserver(numericObserver). + id(ID).build(); + return r; + } + + /* + * Add predicate/RuleSplitNode for a rule + */ + private void updateRuleSplitNode(PredicateContentEvent pce) { + int ruleID = pce.getRuleNumberID(); + for (PassiveRule rule:ruleSet) { + if (rule.getRuleNumberID() == ruleID) { + if (pce.getRuleSplitNode() != null) + rule.nodeListAdd(pce.getRuleSplitNode()); + if (pce.getLearningNode() != null) + rule.setLearningNode(pce.getLearningNode()); + } + } + } + + /* + * Remove rule + */ + private void removeRule(int ruleID) { + for (PassiveRule rule:ruleSet) { + if (rule.getRuleNumberID() == ruleID) { + ruleSet.remove(rule); + break; + } + } + } + + @Override + public void onCreate(int id) { + this.processorId = id; + this.statistics= new double[]{0.0,0,0}; + this.ruleNumberID=0; + this.defaultRule = newRule(++this.ruleNumberID); + + this.ruleSet = new LinkedList<PassiveRule>(); + } + + /* + * Clone processor + */ + @Override + public Processor newProcessor(Processor p) { + AMRulesAggregatorProcessor oldProcessor = (AMRulesAggregatorProcessor) p; + Builder builder = new Builder(oldProcessor); + AMRulesAggregatorProcessor newProcessor = builder.build(); + newProcessor.resultStream = oldProcessor.resultStream; + newProcessor.statisticsStream = oldProcessor.statisticsStream; + return newProcessor; + } + + /* + * Send events + */ + private void sendInstanceToRule(Instance instance, int ruleID) { + AssignmentContentEvent ace = new AssignmentContentEvent(ruleID, instance); + this.statisticsStream.put(ace); + } + + + + private void sendAddRuleEvent(int ruleID, ActiveRule rule) { + RuleContentEvent rce = 
new RuleContentEvent(ruleID, rule, false); + this.statisticsStream.put(rce); + } + + /* + * Output streams + */ + public void setStatisticsStream(Stream statisticsStream) { + this.statisticsStream = statisticsStream; + } + + public Stream getStatisticsStream() { + return this.statisticsStream; + } + + public void setResultStream(Stream resultStream) { + this.resultStream = resultStream; + } + + public Stream getResultStream() { + return this.resultStream; + } + + /* + * Others + */ + public boolean isRandomizable() { + return true; + } + + /* + * Builder + */ + public static class Builder { + private int pageHinckleyThreshold; + private double pageHinckleyAlpha; + private boolean driftDetection; + private int predictionFunction; // Adaptive=0 Perceptron=1 TargetMean=2 + private boolean constantLearningRatioDecay; + private double learningRatio; + private double splitConfidence; + private double tieThreshold; + private int gracePeriod; + + private boolean noAnomalyDetection; + private double multivariateAnomalyProbabilityThreshold; + private double univariateAnomalyprobabilityThreshold; + private int anomalyNumInstThreshold; + + private boolean unorderedRules; + + private FIMTDDNumericAttributeClassLimitObserver numericObserver; + private int voteType; + + private Instances dataset; + + public Builder(Instances dataset){ + this.dataset = dataset; + } + + public Builder(AMRulesAggregatorProcessor processor) { + this.pageHinckleyThreshold = processor.pageHinckleyThreshold; + this.pageHinckleyAlpha = processor.pageHinckleyAlpha; + this.driftDetection = processor.driftDetection; + this.predictionFunction = processor.predictionFunction; + this.constantLearningRatioDecay = processor.constantLearningRatioDecay; + this.learningRatio = processor.learningRatio; + this.splitConfidence = processor.splitConfidence; + this.tieThreshold = processor.tieThreshold; + this.gracePeriod = processor.gracePeriod; + + this.noAnomalyDetection = processor.noAnomalyDetection; + 
this.multivariateAnomalyProbabilityThreshold = processor.multivariateAnomalyProbabilityThreshold; + this.univariateAnomalyprobabilityThreshold = processor.univariateAnomalyprobabilityThreshold; + this.anomalyNumInstThreshold = processor.anomalyNumInstThreshold; + this.unorderedRules = processor.unorderedRules; + + this.numericObserver = processor.numericObserver; + this.voteType = processor.voteType; + } + + public Builder threshold(int threshold) { + this.pageHinckleyThreshold = threshold; + return this; + } + + public Builder alpha(double alpha) { + this.pageHinckleyAlpha = alpha; + return this; + } + + public Builder changeDetection(boolean changeDetection) { + this.driftDetection = changeDetection; + return this; + } + + public Builder predictionFunction(int predictionFunction) { + this.predictionFunction = predictionFunction; + return this; + } + + public Builder constantLearningRatioDecay(boolean constantDecay) { + this.constantLearningRatioDecay = constantDecay; + return this; + } + + public Builder learningRatio(double learningRatio) { + this.learningRatio = learningRatio; + return this; + } + + public Builder splitConfidence(double splitConfidence) { + this.splitConfidence = splitConfidence; + return this; + } + + public Builder tieThreshold(double tieThreshold) { + this.tieThreshold = tieThreshold; + return this; + } + + public Builder gracePeriod(int gracePeriod) { + this.gracePeriod = gracePeriod; + return this; + } + + public Builder noAnomalyDetection(boolean noAnomalyDetection) { + this.noAnomalyDetection = noAnomalyDetection; + return this; + } + + public Builder multivariateAnomalyProbabilityThreshold(double mAnomalyThreshold) { + this.multivariateAnomalyProbabilityThreshold = mAnomalyThreshold; + return this; + } + + public Builder univariateAnomalyProbabilityThreshold(double uAnomalyThreshold) { + this.univariateAnomalyprobabilityThreshold = uAnomalyThreshold; + return this; + } + + public Builder anomalyNumberOfInstancesThreshold(int 
anomalyNumInstThreshold) { + this.anomalyNumInstThreshold = anomalyNumInstThreshold; + return this; + } + + public Builder unorderedRules(boolean unorderedRules) { + this.unorderedRules = unorderedRules; + return this; + } + + public Builder numericObserver(FIMTDDNumericAttributeClassLimitObserver numericObserver) { + this.numericObserver = numericObserver; + return this; + } + + public Builder voteType(int voteType) { + this.voteType = voteType; + return this; + } + + public AMRulesAggregatorProcessor build() { + return new AMRulesAggregatorProcessor(this); + } + } + +}
http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/787864b6/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/distributed/AMRulesStatisticsProcessor.java ---------------------------------------------------------------------- diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/distributed/AMRulesStatisticsProcessor.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/distributed/AMRulesStatisticsProcessor.java new file mode 100644 index 0000000..da820d8 --- /dev/null +++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/distributed/AMRulesStatisticsProcessor.java @@ -0,0 +1,220 @@ +package com.yahoo.labs.samoa.learners.classifiers.rules.distributed; + +/* + * #%L + * SAMOA + * %% + * Copyright (C) 2013 - 2014 Yahoo! Inc. + * %% + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * #L% + */ + +import java.util.Iterator; +import java.util.LinkedList; +import java.util.List; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import com.yahoo.labs.samoa.core.ContentEvent; +import com.yahoo.labs.samoa.core.Processor; +import com.yahoo.labs.samoa.instances.Instance; +import com.yahoo.labs.samoa.instances.Instances; +import com.yahoo.labs.samoa.learners.classifiers.rules.common.ActiveRule; +import com.yahoo.labs.samoa.learners.classifiers.rules.common.RuleActiveRegressionNode; +import com.yahoo.labs.samoa.learners.classifiers.rules.common.RulePassiveRegressionNode; +import com.yahoo.labs.samoa.learners.classifiers.rules.common.RuleSplitNode; +import com.yahoo.labs.samoa.topology.Stream; + +/** + * Learner Processor (VAMR). + * + * @author Anh Thu Vu + * + */ +public class AMRulesStatisticsProcessor implements Processor { + + /** + * + */ + private static final long serialVersionUID = 5268933189695395573L; + + private static final Logger logger = + LoggerFactory.getLogger(AMRulesStatisticsProcessor.class); + + private int processorId; + + private transient List<ActiveRule> ruleSet; + + private Stream outputStream; + + private double splitConfidence; + private double tieThreshold; + private int gracePeriod; + + private int frequency; + + public AMRulesStatisticsProcessor(Builder builder) { + this.splitConfidence = builder.splitConfidence; + this.tieThreshold = builder.tieThreshold; + this.gracePeriod = builder.gracePeriod; + this.frequency = builder.frequency; + } + + @Override + public boolean process(ContentEvent event) { + if (event instanceof AssignmentContentEvent) { + + AssignmentContentEvent attrContentEvent = (AssignmentContentEvent) event; + trainRuleOnInstance(attrContentEvent.getRuleNumberID(),attrContentEvent.getInstance()); + } + else if (event instanceof RuleContentEvent) { + RuleContentEvent ruleContentEvent = (RuleContentEvent) event; + if (!ruleContentEvent.isRemoving()) { + addRule(ruleContentEvent.getRule()); + } + 
} + + return false; + } + + /* + * Process input instances + */ + private void trainRuleOnInstance(int ruleID, Instance instance) { + Iterator<ActiveRule> ruleIterator= this.ruleSet.iterator(); + while (ruleIterator.hasNext()) { + ActiveRule rule = ruleIterator.next(); + if (rule.getRuleNumberID() == ruleID) { + // Check (again) for coverage + // Skip anomaly check as Aggregator's perceptron should be well-updated + if (rule.isCovering(instance) == true) { + double error = rule.computeError(instance); //Use adaptive mode error + boolean changeDetected = ((RuleActiveRegressionNode)rule.getLearningNode()).updateChangeDetection(error); + if (changeDetected == true) { + ruleIterator.remove(); + + this.sendRemoveRuleEvent(ruleID); + } else { + rule.updateStatistics(instance); + if (rule.getInstancesSeen() % this.gracePeriod == 0.0) { + if (rule.tryToExpand(this.splitConfidence, this.tieThreshold) ) { + rule.split(); + + // expanded: update Aggregator with new/updated predicate + this.sendPredicate(rule.getRuleNumberID(), rule.getLastUpdatedRuleSplitNode(), + (RuleActiveRegressionNode)rule.getLearningNode()); + } + } + } + } + + return; + } + } + } + + private void sendRemoveRuleEvent(int ruleID) { + RuleContentEvent rce = new RuleContentEvent(ruleID, null, true); + this.outputStream.put(rce); + } + + private void sendPredicate(int ruleID, RuleSplitNode splitNode, RuleActiveRegressionNode learningNode) { + this.outputStream.put(new PredicateContentEvent(ruleID, splitNode, new RulePassiveRegressionNode(learningNode))); + } + + /* + * Process control message (regarding adding or removing rules) + */ + private boolean addRule(ActiveRule rule) { + this.ruleSet.add(rule); + return true; + } + + @Override + public void onCreate(int id) { + this.processorId = id; + this.ruleSet = new LinkedList<ActiveRule>(); + } + + @Override + public Processor newProcessor(Processor p) { + AMRulesStatisticsProcessor oldProcessor = (AMRulesStatisticsProcessor)p; + AMRulesStatisticsProcessor 
newProcessor = + new AMRulesStatisticsProcessor.Builder(oldProcessor).build(); + + newProcessor.setOutputStream(oldProcessor.outputStream); + return newProcessor; + } + + /* + * Builder + */ + public static class Builder { + private double splitConfidence; + private double tieThreshold; + private int gracePeriod; + + private int frequency; + + private Instances dataset; + + public Builder(Instances dataset){ + this.dataset = dataset; + } + + public Builder(AMRulesStatisticsProcessor processor) { + this.splitConfidence = processor.splitConfidence; + this.tieThreshold = processor.tieThreshold; + this.gracePeriod = processor.gracePeriod; + this.frequency = processor.frequency; + } + + public Builder splitConfidence(double splitConfidence) { + this.splitConfidence = splitConfidence; + return this; + } + + public Builder tieThreshold(double tieThreshold) { + this.tieThreshold = tieThreshold; + return this; + } + + public Builder gracePeriod(int gracePeriod) { + this.gracePeriod = gracePeriod; + return this; + } + + public Builder frequency(int frequency) { + this.frequency = frequency; + return this; + } + + public AMRulesStatisticsProcessor build() { + return new AMRulesStatisticsProcessor(this); + } + } + + /* + * Output stream + */ + public void setOutputStream(Stream stream) { + this.outputStream = stream; + } + + public Stream getOutputStream() { + return this.outputStream; + } + +} http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/787864b6/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/distributed/AssignmentContentEvent.java ---------------------------------------------------------------------- diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/distributed/AssignmentContentEvent.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/distributed/AssignmentContentEvent.java new file mode 100644 index 0000000..5a03406 --- /dev/null +++ 
b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/distributed/AssignmentContentEvent.java @@ -0,0 +1,74 @@ +package com.yahoo.labs.samoa.learners.classifiers.rules.distributed; + +/* + * #%L + * SAMOA + * %% + * Copyright (C) 2013 - 2014 Yahoo! Inc. + * %% + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * #L% + */ + +import com.yahoo.labs.samoa.core.ContentEvent; +import com.yahoo.labs.samoa.instances.Instance; + +/** + * Forwarded instances from Model Aggregator to Learners/Default Rule Learner. 
+ * + * @author Anh Thu Vu + * + */ +public class AssignmentContentEvent implements ContentEvent { + + /** + * + */ + private static final long serialVersionUID = 1031695762172836629L; + + private int ruleNumberID; + private Instance instance; + + public AssignmentContentEvent() { + this(0, null); + } + + public AssignmentContentEvent(int ruleID, Instance instance) { + this.ruleNumberID = ruleID; + this.instance = instance; + } + + @Override + public String getKey() { + return Integer.toString(this.ruleNumberID); + } + + @Override + public void setKey(String key) { + // do nothing + } + + @Override + public boolean isLastEvent() { + return false; + } + + public Instance getInstance() { + return this.instance; + } + + public int getRuleNumberID() { + return this.ruleNumberID; + } + +} http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/787864b6/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/distributed/PredicateContentEvent.java ---------------------------------------------------------------------- diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/distributed/PredicateContentEvent.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/distributed/PredicateContentEvent.java new file mode 100644 index 0000000..69e935a --- /dev/null +++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/distributed/PredicateContentEvent.java @@ -0,0 +1,84 @@ +package com.yahoo.labs.samoa.learners.classifiers.rules.distributed; + +/* + * #%L + * SAMOA + * %% + * Copyright (C) 2013 - 2014 Yahoo! Inc. + * %% + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * #L% + */ + +import com.yahoo.labs.samoa.core.ContentEvent; +import com.yahoo.labs.samoa.learners.classifiers.rules.common.RulePassiveRegressionNode; +import com.yahoo.labs.samoa.learners.classifiers.rules.common.RuleSplitNode; + +/** + * New features (of newly expanded rules) from Learners to Model Aggregators. + * + * @author Anh Thu Vu + * + */ +public class PredicateContentEvent implements ContentEvent { + + /** + * + */ + private static final long serialVersionUID = 7909435830443732451L; + + private int ruleNumberID; + private RuleSplitNode ruleSplitNode; + private RulePassiveRegressionNode learningNode; + + /* + * Constructor + */ + public PredicateContentEvent() { + this(0, null, null); + } + + public PredicateContentEvent (int ruleID, RuleSplitNode ruleSplitNode, RulePassiveRegressionNode learningNode) { + this.ruleNumberID = ruleID; + this.ruleSplitNode = ruleSplitNode; // is this is null: this is for updating learningNode only + this.learningNode = learningNode; + } + + @Override + public String getKey() { + return Integer.toString(this.ruleNumberID); + } + + @Override + public void setKey(String key) { + // do nothing + } + + @Override + public boolean isLastEvent() { + return false; // N/A + } + + public int getRuleNumberID() { + return this.ruleNumberID; + } + + public RuleSplitNode getRuleSplitNode() { + return this.ruleSplitNode; + } + + public RulePassiveRegressionNode getLearningNode() { + return this.learningNode; + } + +} 
http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/787864b6/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/distributed/RuleContentEvent.java ---------------------------------------------------------------------- diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/distributed/RuleContentEvent.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/distributed/RuleContentEvent.java new file mode 100644 index 0000000..a9dab4a --- /dev/null +++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/distributed/RuleContentEvent.java @@ -0,0 +1,82 @@ +package com.yahoo.labs.samoa.learners.classifiers.rules.distributed; + +/* + * #%L + * SAMOA + * %% + * Copyright (C) 2013 - 2014 Yahoo! Inc. + * %% + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * #L% + */ + +import com.yahoo.labs.samoa.core.ContentEvent; +import com.yahoo.labs.samoa.learners.classifiers.rules.common.ActiveRule; + +/** + * New rule from Model Aggregator/Default Rule Learner to Learners + * or removed rule from Learner to Model Aggregators. 
+ * + * @author Anh Thu Vu + * + */ +public class RuleContentEvent implements ContentEvent { + + + /** + * + */ + private static final long serialVersionUID = -9046390274402894461L; + + private final int ruleNumberID; + private final ActiveRule addingRule; // for removing rule, we only need the rule's ID + private final boolean isRemoving; + + public RuleContentEvent() { + this(0, null, false); + } + + public RuleContentEvent(int ruleID, ActiveRule rule, boolean isRemoving) { + this.ruleNumberID = ruleID; + this.isRemoving = isRemoving; + this.addingRule = rule; + } + + @Override + public String getKey() { + return Integer.toString(this.ruleNumberID); + } + + @Override + public void setKey(String key) { + // do nothing + } + + @Override + public boolean isLastEvent() { + return false; + } + + public int getRuleNumberID() { + return this.ruleNumberID; + } + + public ActiveRule getRule() { + return this.addingRule; + } + + public boolean isRemoving() { + return this.isRemoving; + } + +} http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/787864b6/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/trees/ActiveLearningNode.java ---------------------------------------------------------------------- diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/trees/ActiveLearningNode.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/trees/ActiveLearningNode.java new file mode 100644 index 0000000..40d260c --- /dev/null +++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/trees/ActiveLearningNode.java @@ -0,0 +1,207 @@ +package com.yahoo.labs.samoa.learners.classifiers.trees; + +/* + * #%L + * SAMOA + * %% + * Copyright (C) 2013 Yahoo! Inc. + * %% + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * #L% + */ + +import java.util.HashMap; +import java.util.Map; + +import com.yahoo.labs.samoa.moa.classifiers.core.AttributeSplitSuggestion; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import com.yahoo.labs.samoa.instances.Instance; + +final class ActiveLearningNode extends LearningNode { + /** + * + */ + private static final long serialVersionUID = -2892102872646338908L; + private static final Logger logger = LoggerFactory.getLogger(ActiveLearningNode.class); + + private double weightSeenAtLastSplitEvaluation; + + private final Map<Integer, String> attributeContentEventKeys; + + private AttributeSplitSuggestion bestSuggestion; + private AttributeSplitSuggestion secondBestSuggestion; + + private final long id; + private final int parallelismHint; + private int suggestionCtr; + private int thrownAwayInstance; + + private boolean isSplitting; + + ActiveLearningNode(double[] classObservation, int parallelismHint) { + super(classObservation); + this.weightSeenAtLastSplitEvaluation = this.getWeightSeen(); + this.id = VerticalHoeffdingTree.LearningNodeIdGenerator.generate(); + this.attributeContentEventKeys = new HashMap<>(); + this.isSplitting = false; + this.parallelismHint = parallelismHint; + } + + long getId(){ + return id; + } + + protected AttributeBatchContentEvent[] attributeBatchContentEvent; + + public AttributeBatchContentEvent[] getAttributeBatchContentEvent() { + return this.attributeBatchContentEvent; + } + + public void setAttributeBatchContentEvent(AttributeBatchContentEvent[] attributeBatchContentEvent) { + 
this.attributeBatchContentEvent = attributeBatchContentEvent; + } + + @Override + void learnFromInstance(Instance inst, ModelAggregatorProcessor proc) { + //TODO: what statistics should we keep for unused instance? + if(isSplitting){ //currently throw all instance will splitting + this.thrownAwayInstance++; + return; + } + this.observedClassDistribution.addToValue((int)inst.classValue(), + inst.weight()); + //done: parallelize by sending attributes one by one + //TODO: meanwhile, we can try to use the ThreadPool to execute it separately + //TODO: parallelize by sending in batch, i.e. split the attributes into + //chunk instead of send the attribute one by one + for(int i = 0; i < inst.numAttributes() - 1; i++){ + int instAttIndex = modelAttIndexToInstanceAttIndex(i, inst); + Integer obsIndex = i; + String key = attributeContentEventKeys.get(obsIndex); + + if(key == null){ + key = this.generateKey(i); + attributeContentEventKeys.put(obsIndex, key); + } + AttributeContentEvent ace = new AttributeContentEvent.Builder( + this.id, i, key) + .attrValue(inst.value(instAttIndex)) + .classValue((int) inst.classValue()) + .weight(inst.weight()) + .isNominal(inst.attribute(instAttIndex).isNominal()) + .build(); + if (this.attributeBatchContentEvent == null){ + this.attributeBatchContentEvent = new AttributeBatchContentEvent[inst.numAttributes() - 1]; + } + if (this.attributeBatchContentEvent[i] == null){ + this.attributeBatchContentEvent[i] = new AttributeBatchContentEvent.Builder( + this.id, i, key) + //.attrValue(inst.value(instAttIndex)) + //.classValue((int) inst.classValue()) + //.weight(inst.weight()] + .isNominal(inst.attribute(instAttIndex).isNominal()) + .build(); + } + this.attributeBatchContentEvent[i].add(ace); + //proc.sendToAttributeStream(ace); + } + } + + @Override + double[] getClassVotes(Instance inst, ModelAggregatorProcessor map) { + return this.observedClassDistribution.getArrayCopy(); + } + + double getWeightSeen(){ + return 
this.observedClassDistribution.sumOfValues(); + } + + void setWeightSeenAtLastSplitEvaluation(double weight){ + this.weightSeenAtLastSplitEvaluation = weight; + } + + double getWeightSeenAtLastSplitEvaluation(){ + return this.weightSeenAtLastSplitEvaluation; + } + + void requestDistributedSuggestions(long splitId, ModelAggregatorProcessor modelAggrProc) { + this.isSplitting = true; + this.suggestionCtr = 0; + this.thrownAwayInstance = 0; + + ComputeContentEvent cce = new ComputeContentEvent(splitId, this.id, + this.getObservedClassDistribution()); + modelAggrProc.sendToControlStream(cce); + } + + void addDistributedSuggestions(AttributeSplitSuggestion bestSuggestion, AttributeSplitSuggestion secondBestSuggestion){ + //starts comparing from the best suggestion + if(bestSuggestion != null){ + if((this.bestSuggestion == null) || (bestSuggestion.compareTo(this.bestSuggestion) > 0)){ + this.secondBestSuggestion = this.bestSuggestion; + this.bestSuggestion = bestSuggestion; + + if(secondBestSuggestion != null){ + + if((this.secondBestSuggestion == null) || (secondBestSuggestion.compareTo(this.secondBestSuggestion) > 0)){ + this.secondBestSuggestion = secondBestSuggestion; + } + } + }else{ + if((this.secondBestSuggestion == null) || (bestSuggestion.compareTo(this.secondBestSuggestion) > 0)){ + this.secondBestSuggestion = bestSuggestion; + } + } + } + + //TODO: optimize the code to use less memory + this.suggestionCtr++; + } + + boolean isSplitting(){ + return this.isSplitting; + } + + void endSplitting(){ + this.isSplitting = false; + logger.trace("wasted instance: {}", this.thrownAwayInstance); + this.thrownAwayInstance = 0; + } + + AttributeSplitSuggestion getDistributedBestSuggestion(){ + return this.bestSuggestion; + } + + AttributeSplitSuggestion getDistributedSecondBestSuggestion(){ + return this.secondBestSuggestion; + } + + boolean isAllSuggestionsCollected(){ + return (this.suggestionCtr == this.parallelismHint); + } + + private static int 
modelAttIndexToInstanceAttIndex(int index, Instance inst){ + return inst.classIndex() > index ? index : index + 1; + } + + private String generateKey(int obsIndex){ + final int prime = 31; + int result = 1; + result = prime * result + (int) (this.id ^ (this.id >>> 32)); + result = prime * result + obsIndex; + return Integer.toString(result); + } +} http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/787864b6/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/trees/AttributeBatchContentEvent.java ---------------------------------------------------------------------- diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/trees/AttributeBatchContentEvent.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/trees/AttributeBatchContentEvent.java new file mode 100644 index 0000000..691d0fb --- /dev/null +++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/trees/AttributeBatchContentEvent.java @@ -0,0 +1,134 @@ +package com.yahoo.labs.samoa.learners.classifiers.trees; + +/* + * #%L + * SAMOA + * %% + * Copyright (C) 2013 Yahoo! Inc. + * %% + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * #L% + */ + +import com.yahoo.labs.samoa.core.ContentEvent; +import java.util.LinkedList; +import java.util.List; + +/** + * Attribute Content Event represents the instances that split vertically + * based on their attribute + * @author Arinto Murdopo + * + */ +final class AttributeBatchContentEvent implements ContentEvent { + + private static final long serialVersionUID = 6652815649846676832L; + + private final long learningNodeId; + private final int obsIndex; + private final List<ContentEvent> contentEventList; + private final transient String key; + private final boolean isNominal; + + public AttributeBatchContentEvent(){ + learningNodeId = -1; + obsIndex = -1; + contentEventList = new LinkedList<>(); + key = ""; + isNominal = true; + } + + private AttributeBatchContentEvent(Builder builder){ + this.learningNodeId = builder.learningNodeId; + this.obsIndex = builder.obsIndex; + this.contentEventList = new LinkedList<>(); + if (builder.contentEvent != null) { + this.contentEventList.add(builder.contentEvent); + } + this.isNominal = builder.isNominal; + this.key = builder.key; + } + + public void add(ContentEvent contentEvent){ + this.contentEventList.add(contentEvent); + } + + @Override + public String getKey() { + return this.key; + } + + @Override + public void setKey(String str) { + //do nothing, maybe useful when we want to reuse the object for serialization/deserialization purpose + } + + @Override + public boolean isLastEvent() { + return false; + } + + long getLearningNodeId(){ + return this.learningNodeId; + } + + int getObsIndex(){ + return this.obsIndex; + } + + public List<ContentEvent> getContentEventList(){ + return this.contentEventList; + } + + boolean isNominal(){ + return this.isNominal; + } + + static final class Builder{ + + //required parameters + private final long learningNodeId; + private final int obsIndex; + private final String key; + + private ContentEvent contentEvent; + private boolean isNominal = false; + + Builder(long id, int 
obsIndex, String key){ + this.learningNodeId = id; + this.obsIndex = obsIndex; + this.key = key; + } + + private Builder(long id, int obsIndex){ + this.learningNodeId = id; + this.obsIndex = obsIndex; + this.key = ""; + } + + Builder contentEvent(ContentEvent contentEvent){ + this.contentEvent = contentEvent; + return this; + } + + Builder isNominal(boolean val){ + this.isNominal = val; + return this; + } + + AttributeBatchContentEvent build(){ + return new AttributeBatchContentEvent(this); + } + } + +} http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/787864b6/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/trees/AttributeContentEvent.java ---------------------------------------------------------------------- diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/trees/AttributeContentEvent.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/trees/AttributeContentEvent.java new file mode 100644 index 0000000..4cbdd95 --- /dev/null +++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/trees/AttributeContentEvent.java @@ -0,0 +1,222 @@ +package com.yahoo.labs.samoa.learners.classifiers.trees; + +/* + * #%L + * SAMOA + * %% + * Copyright (C) 2013 Yahoo! Inc. + * %% + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * #L% + */ + +import com.esotericsoftware.kryo.Kryo; +import com.esotericsoftware.kryo.Serializer; +import com.esotericsoftware.kryo.io.Input; +import com.esotericsoftware.kryo.io.Output; + +import com.yahoo.labs.samoa.core.ContentEvent; + +/** + * Attribute Content Event represents the instances that split vertically + * based on their attribute + * @author Arinto Murdopo + * + */ +public final class AttributeContentEvent implements ContentEvent { + + private static final long serialVersionUID = 6652815649846676832L; + + private final long learningNodeId; + private final int obsIndex; + private final double attrVal; + private final int classVal; + private final double weight; + private final transient String key; + private final boolean isNominal; + + public AttributeContentEvent(){ + learningNodeId = -1; + obsIndex = -1; + attrVal = 0.0; + classVal = -1; + weight = 0.0; + key = ""; + isNominal = true; + } + + private AttributeContentEvent(Builder builder){ + this.learningNodeId = builder.learningNodeId; + this.obsIndex = builder.obsIndex; + this.attrVal = builder.attrVal; + this.classVal = builder.classVal; + this.weight = builder.weight; + this.isNominal = builder.isNominal; + this.key = builder.key; + } + + @Override + public String getKey() { + return this.key; + } + + @Override + public void setKey(String str) { + //do nothing, maybe useful when we want to reuse the object for serialization/deserialization purpose + } + + @Override + public boolean isLastEvent() { + return false; + } + + long getLearningNodeId(){ + return this.learningNodeId; + } + + int getObsIndex(){ + return this.obsIndex; + } + + int getClassVal(){ + return this.classVal; + } + + double getAttrVal(){ + return this.attrVal; + } + + double getWeight(){ + return this.weight; + } + + boolean isNominal(){ + return this.isNominal; + } + + static final class Builder{ + + //required parameters + private final long learningNodeId; + private final int obsIndex; + private final String key; + + 
//optional parameters + private double attrVal = 0.0; + private int classVal = 0; + private double weight = 0.0; + private boolean isNominal = false; + + Builder(long id, int obsIndex, String key){ + this.learningNodeId = id; + this.obsIndex = obsIndex; + this.key = key; + } + + private Builder(long id, int obsIndex){ + this.learningNodeId = id; + this.obsIndex = obsIndex; + this.key = ""; + } + + Builder attrValue(double val){ + this.attrVal = val; + return this; + } + + Builder classValue(int val){ + this.classVal = val; + return this; + } + + Builder weight(double val){ + this.weight = val; + return this; + } + + Builder isNominal(boolean val){ + this.isNominal = val; + return this; + } + + AttributeContentEvent build(){ + return new AttributeContentEvent(this); + } + } + + /** + * The Kryo serializer class for AttributeContentEvent when executing on top of Storm. + * This class allow us to change the precision of the statistics. + * @author Arinto Murdopo + * + */ + public static final class AttributeCESerializer extends Serializer<AttributeContentEvent>{ + + private static double PRECISION = 1000000.0; + @Override + public void write(Kryo kryo, Output output, AttributeContentEvent event) { + output.writeLong(event.learningNodeId, true); + output.writeInt(event.obsIndex, true); + output.writeDouble(event.attrVal, PRECISION, true); + output.writeInt(event.classVal, true); + output.writeDouble(event.weight, PRECISION, true); + output.writeBoolean(event.isNominal); + } + + @Override + public AttributeContentEvent read(Kryo kryo, Input input, + Class<AttributeContentEvent> type) { + AttributeContentEvent ace + = new AttributeContentEvent.Builder(input.readLong(true), input.readInt(true)) + .attrValue(input.readDouble(PRECISION, true)) + .classValue(input.readInt(true)) + .weight(input.readDouble(PRECISION, true)) + .isNominal(input.readBoolean()) + .build(); + return ace; + } + } + + /** + * The Kryo serializer class for AttributeContentEvent when executing on top 
of Storm + * with full precision of the statistics. + * @author Arinto Murdopo + * + */ + public static final class AttributeCEFullPrecSerializer extends Serializer<AttributeContentEvent>{ + + @Override + public void write(Kryo kryo, Output output, AttributeContentEvent event) { + output.writeLong(event.learningNodeId, true); + output.writeInt(event.obsIndex, true); + output.writeDouble(event.attrVal); + output.writeInt(event.classVal, true); + output.writeDouble(event.weight); + output.writeBoolean(event.isNominal); + } + + @Override + public AttributeContentEvent read(Kryo kryo, Input input, + Class<AttributeContentEvent> type) { + AttributeContentEvent ace + = new AttributeContentEvent.Builder(input.readLong(true), input.readInt(true)) + .attrValue(input.readDouble()) + .classValue(input.readInt(true)) + .weight(input.readDouble()) + .isNominal(input.readBoolean()) + .build(); + return ace; + } + + } +} http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/787864b6/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/trees/ComputeContentEvent.java ---------------------------------------------------------------------- diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/trees/ComputeContentEvent.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/trees/ComputeContentEvent.java new file mode 100644 index 0000000..52f4685 --- /dev/null +++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/trees/ComputeContentEvent.java @@ -0,0 +1,142 @@ +package com.yahoo.labs.samoa.learners.classifiers.trees; + +/* + * #%L + * SAMOA + * %% + * Copyright (C) 2013 Yahoo! Inc. + * %% + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * #L% + */ + +import com.esotericsoftware.kryo.Kryo; +import com.esotericsoftware.kryo.Serializer; +import com.esotericsoftware.kryo.io.Input; +import com.esotericsoftware.kryo.io.Output; + +/** + * Compute content event is the message that is sent by Model Aggregator Processor + * to request Local Statistic PI to start the local statistic calculation for splitting + * @author Arinto Murdopo + * + */ +public final class ComputeContentEvent extends ControlContentEvent { + + private static final long serialVersionUID = 5590798490073395190L; + + private final double[] preSplitDist; + private final long splitId; + + public ComputeContentEvent(){ + super(-1); + preSplitDist = null; + splitId = -1; + } + + ComputeContentEvent(long splitId, long id, double[] preSplitDist) { + super(id); + //this.preSplitDist = Arrays.copyOf(preSplitDist, preSplitDist.length); + this.preSplitDist = preSplitDist; + this.splitId = splitId; + } + + @Override + LocStatControl getType() { + return LocStatControl.COMPUTE; + } + + double[] getPreSplitDist(){ + return this.preSplitDist; + } + + long getSplitId(){ + return this.splitId; + } + + /** + * The Kryo serializer class for ComputeContentEevent when executing on top of Storm. + * This class allow us to change the precision of the statistics. 
+ * @author Arinto Murdopo + * + */ + public static final class ComputeCESerializer extends Serializer<ComputeContentEvent>{ + + private static double PRECISION = 1000000.0; + + @Override + public void write(Kryo kryo, Output output, ComputeContentEvent object) { + output.writeLong(object.splitId, true); + output.writeLong(object.learningNodeId, true); + + output.writeInt(object.preSplitDist.length, true); + for(int i = 0; i < object.preSplitDist.length; i++){ + output.writeDouble(object.preSplitDist[i], PRECISION, true); + } + } + + @Override + public ComputeContentEvent read(Kryo kryo, Input input, + Class<ComputeContentEvent> type) { + long splitId = input.readLong(true); + long learningNodeId = input.readLong(true); + + int dataLength = input.readInt(true); + double[] preSplitDist = new double[dataLength]; + + for(int i = 0; i < dataLength; i++){ + preSplitDist[i] = input.readDouble(PRECISION, true); + } + + return new ComputeContentEvent(splitId, learningNodeId, preSplitDist); + } + } + + /** + * The Kryo serializer class for ComputeContentEevent when executing on top of Storm + * with full precision of the statistics. 
+ * @author Arinto Murdopo + * + */ + public static final class ComputeCEFullPrecSerializer extends Serializer<ComputeContentEvent>{ + + @Override + public void write(Kryo kryo, Output output, ComputeContentEvent object) { + output.writeLong(object.splitId, true); + output.writeLong(object.learningNodeId, true); + + output.writeInt(object.preSplitDist.length, true); + for(int i = 0; i < object.preSplitDist.length; i++){ + output.writeDouble(object.preSplitDist[i]); + } + } + + @Override + public ComputeContentEvent read(Kryo kryo, Input input, + Class<ComputeContentEvent> type) { + long splitId = input.readLong(true); + long learningNodeId = input.readLong(true); + + int dataLength = input.readInt(true); + double[] preSplitDist = new double[dataLength]; + + for(int i = 0; i < dataLength; i++){ + preSplitDist[i] = input.readDouble(); + } + + return new ComputeContentEvent(splitId, learningNodeId, preSplitDist); + } + + } + +} http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/787864b6/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/trees/ControlContentEvent.java ---------------------------------------------------------------------- diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/trees/ControlContentEvent.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/trees/ControlContentEvent.java new file mode 100644 index 0000000..201ef88 --- /dev/null +++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/trees/ControlContentEvent.java @@ -0,0 +1,71 @@ +package com.yahoo.labs.samoa.learners.classifiers.trees; + +/* + * #%L + * SAMOA + * %% + * Copyright (C) 2013 Yahoo! Inc. + * %% + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * #L% + */ + +import com.yahoo.labs.samoa.core.ContentEvent; + +/** + * Abstract class to represent ContentEvent to control Local Statistic Processor. + * @author Arinto Murdopo + * + */ +abstract class ControlContentEvent implements ContentEvent { + + /** + * + */ + private static final long serialVersionUID = 5837375639629708363L; + + protected final long learningNodeId; + + public ControlContentEvent(){ + this.learningNodeId = -1; + } + + ControlContentEvent(long id){ + this.learningNodeId = id; + } + + @Override + public final String getKey() { + return null; + } + + @Override + public void setKey(String str){ + //Do nothing + } + + @Override + public boolean isLastEvent(){ + return false; + } + + final long getLearningNodeId(){ + return this.learningNodeId; + } + + abstract LocStatControl getType(); + + static enum LocStatControl { + COMPUTE, DELETE + } +} http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/787864b6/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/trees/DeleteContentEvent.java ---------------------------------------------------------------------- diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/trees/DeleteContentEvent.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/trees/DeleteContentEvent.java new file mode 100644 index 0000000..c721255 --- /dev/null +++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/trees/DeleteContentEvent.java @@ -0,0 +1,45 @@ +package com.yahoo.labs.samoa.learners.classifiers.trees; + +/* + * #%L + * 
SAMOA + * %% + * Copyright (C) 2013 Yahoo! Inc. + * %% + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * #L% + */ + +/** + * Delete Content Event is the content event that is sent by Model Aggregator Processor + * to delete unnecessary statistic in Local Statistic Processor. + * @author Arinto Murdopo + * + */ +final class DeleteContentEvent extends ControlContentEvent { + + private static final long serialVersionUID = -2105250722560863633L; + + public DeleteContentEvent(){ + super(-1); + } + + DeleteContentEvent(long id) { + super(id); } + + @Override + LocStatControl getType() { + return LocStatControl.DELETE; + } + +} http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/787864b6/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/trees/FilterProcessor.java ---------------------------------------------------------------------- diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/trees/FilterProcessor.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/trees/FilterProcessor.java new file mode 100644 index 0000000..b6a73c6 --- /dev/null +++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/trees/FilterProcessor.java @@ -0,0 +1,191 @@ +package com.yahoo.labs.samoa.learners.classifiers.trees; + +/* + * #%L + * SAMOA + * %% + * Copyright (C) 2013 Yahoo! Inc. 
+ * %% + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * #L% + */ + + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import com.yahoo.labs.samoa.core.ContentEvent; +import com.yahoo.labs.samoa.learners.InstanceContentEvent; +import com.yahoo.labs.samoa.core.Processor; +import com.yahoo.labs.samoa.instances.Instance; +import com.yahoo.labs.samoa.learners.ResultContentEvent; +import com.yahoo.labs.samoa.instances.Instances; +import com.yahoo.labs.samoa.instances.InstancesHeader; +import com.yahoo.labs.samoa.learners.InstancesContentEvent; +import com.yahoo.labs.samoa.topology.Stream; +import java.util.LinkedList; +import java.util.List; + +/** + * Filter Processor that stores and filters the instances before + * sending them to the Model Aggregator Processor. 
+ + * @author Arinto Murdopo + * + */ +final class FilterProcessor implements Processor { + + private static final long serialVersionUID = -1685875718300564885L; + private static final Logger logger = LoggerFactory.getLogger(FilterProcessor.class); + + private int processorId; + + private final Instances dataset; + private InstancesHeader modelContext; + + //available streams + private Stream outputStream; + + //private constructor based on Builder pattern + private FilterProcessor(Builder builder){ + this.dataset = builder.dataset; + this.batchSize = builder.batchSize; + this.delay = builder.delay; + } + + private int waitingInstances = 0; + + private int delay = 0; + + private int batchSize = 200; + + private List<InstanceContentEvent> contentEventList = new LinkedList<InstanceContentEvent>(); + + @Override + public boolean process(ContentEvent event) { + //Receive a new instance from source + if(event instanceof InstanceContentEvent){ + InstanceContentEvent instanceContentEvent = (InstanceContentEvent) event; + this.contentEventList.add(instanceContentEvent); + this.waitingInstances++; + if (this.waitingInstances == this.batchSize || instanceContentEvent.isLastEvent()){ + //Send Instances + InstancesContentEvent outputEvent = new InstancesContentEvent(instanceContentEvent); + boolean isLastEvent = false; + while (!this.contentEventList.isEmpty()){ + InstanceContentEvent ice = this.contentEventList.remove(0); + Instance inst = ice.getInstance(); + outputEvent.add(inst); + if (!isLastEvent) { + isLastEvent = ice.isLastEvent(); + } + } + outputEvent.setLast(isLastEvent); + this.waitingInstances = 0; + this.outputStream.put(outputEvent); + if (this.delay > 0) { + try { + Thread.sleep(this.delay); + } catch(InterruptedException ex) { + Thread.currentThread().interrupt(); + } + } + } + } + return false; + } + + @Override + public void onCreate(int id) { + this.processorId = id; + this.waitingInstances = 0; + + } + + @Override + public Processor newProcessor(Processor 
p) { + FilterProcessor oldProcessor = (FilterProcessor)p; + FilterProcessor newProcessor = + new FilterProcessor.Builder(oldProcessor).build(); + + newProcessor.setOutputStream(oldProcessor.outputStream); + return newProcessor; + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + sb.append(super.toString()); + return sb.toString(); + } + + void setOutputStream(Stream outputStream){ + this.outputStream = outputStream; + } + + + /** + * Helper method to generate new ResultContentEvent based on an instance and + * its prediction result. + * @param prediction The predicted class label from the decision tree model. + * @param inEvent The associated instance content event + * @return ResultContentEvent to be sent into Evaluator PI or other destination PI. + */ + private ResultContentEvent newResultContentEvent(double[] prediction, InstanceContentEvent inEvent){ + ResultContentEvent rce = new ResultContentEvent(inEvent.getInstanceIndex(), inEvent.getInstance(), inEvent.getClassId(), prediction, inEvent.isLastEvent()); + rce.setClassifierIndex(this.processorId); + rce.setEvaluationIndex(inEvent.getEvaluationIndex()); + return rce; + } + + + /** + * Builder class to replace constructors with many parameters + * @author Arinto Murdopo + * + */ + static class Builder{ + + //required parameters + private final Instances dataset; + + private int delay = 0; + + private int batchSize = 200; + + Builder(Instances dataset){ + this.dataset = dataset; + } + + Builder(FilterProcessor oldProcessor){ + this.dataset = oldProcessor.dataset; + this.delay = oldProcessor.delay; + this.batchSize = oldProcessor.batchSize; + } + + public Builder delay(int delay){ + this.delay = delay; + return this; + } + + public Builder batchSize(int val){ + this.batchSize = val; + return this; + } + + FilterProcessor build(){ + return new FilterProcessor(this); + } + } + +} 
http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/787864b6/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/trees/FoundNode.java ---------------------------------------------------------------------- diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/trees/FoundNode.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/trees/FoundNode.java new file mode 100644 index 0000000..4123ea5 --- /dev/null +++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/trees/FoundNode.java @@ -0,0 +1,77 @@ +package com.yahoo.labs.samoa.learners.classifiers.trees; + +/* + * #%L + * SAMOA + * %% + * Copyright (C) 2013 Yahoo! Inc. + * %% + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * #L% + */ + +/** + * Class that represents the necessary data structure of the node where an instance + * is routed/filtered through the decision tree model. 
+ * + * @author Arinto Murdopo + * + */ +final class FoundNode implements java.io.Serializable{ + + /** + * + */ + private static final long serialVersionUID = -637695387934143293L; + + private final Node node; + private final SplitNode parent; + private final int parentBranch; + + FoundNode(Node node, SplitNode splitNode, int parentBranch){ + this.node = node; + this.parent = splitNode; + this.parentBranch = parentBranch; + } + + /** + * Method to get the node where an instance is routed/filtered through the decision tree + * model for testing and training. + * + * @return The node where the instance is routed/filtered + */ + Node getNode(){ + return this.node; + } + + /** + * Method to get the parent of the node where an instance is routed/filtered through the decision tree + * model for testing and training + * + * @return The parent of the node + */ + SplitNode getParent(){ + return this.parent; + } + + /** + * Method to get the index of the node (where an instance is routed/filtered through the decision tree + * model for testing and training) in its parent. + * + * @return The index of the node in its parent node. + */ + int getParentBranch(){ + return this.parentBranch; + } + +} \ No newline at end of file http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/787864b6/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/trees/InactiveLearningNode.java ---------------------------------------------------------------------- diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/trees/InactiveLearningNode.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/trees/InactiveLearningNode.java new file mode 100644 index 0000000..82a05de --- /dev/null +++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/trees/InactiveLearningNode.java @@ -0,0 +1,56 @@ +package com.yahoo.labs.samoa.learners.classifiers.trees; + +/* + * #%L + * SAMOA + * %% + * Copyright (C) 2013 Yahoo! Inc. 
+ * %% + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * #L% + */ + +import com.yahoo.labs.samoa.instances.Instance; +
/**
 * Inactive learning node: a learning node that only keeps track of the
 * observed class distribution and does not store the statistics needed for
 * splitting the node.
 *
 * @author Arinto Murdopo
 */
final class InactiveLearningNode extends LearningNode {

    private static final long serialVersionUID = -814552382883472302L;

    /**
     * @param initialClassObservation class observation handed to the
     *                                {@link LearningNode} constructor
     */
    InactiveLearningNode(double[] initialClassObservation) {
        super(initialClassObservation);
    }

    /**
     * Adds the instance's weight to the count of its class value; the
     * processor argument is not used by this node type.
     */
    @Override
    void learnFromInstance(Instance inst, ModelAggregatorProcessor proc) {
        final int observedClass = (int) inst.classValue();
        observedClassDistribution.addToValue(observedClass, inst.weight());
    }

    /**
     * Votes with a copy of the observed class distribution, regardless of
     * the instance passed in.
     */
    @Override
    double[] getClassVotes(Instance inst, ModelAggregatorProcessor map) {
        final double[] votes = observedClassDistribution.getArrayCopy();
        return votes;
    }

}
http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/787864b6/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/trees/LearningNode.java ---------------------------------------------------------------------- diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/trees/LearningNode.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/trees/LearningNode.java new file mode 100644 index 0000000..58de671 --- /dev/null +++
b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/trees/LearningNode.java @@ -0,0 +1,55 @@ +package com.yahoo.labs.samoa.learners.classifiers.trees; + +/* + * #%L + * SAMOA + * %% + * Copyright (C) 2013 Yahoo! Inc. + * %% + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * #L% + */ + +import com.yahoo.labs.samoa.instances.Instance; +
/**
 * Abstract base for learning nodes: nodes that report themselves as leaves
 * and can be trained from instances.
 *
 * @author Arinto Murdopo
 */
abstract class LearningNode extends Node {

    private static final long serialVersionUID = 7157319356146764960L;

    /**
     * @param classObservation class observation forwarded to the {@link Node}
     *                         constructor
     */
    protected LearningNode(double[] classObservation) {
        super(classObservation);
    }

    /** A learning node is always a leaf. */
    @Override
    protected boolean isLeaf() {
        return true;
    }

    /**
     * Routing stops here: the result is this node together with the parent
     * and branch index it was reached through.
     */
    @Override
    protected FoundNode filterInstanceToLeaf(Instance inst, SplitNode parent,
            int parentBranch) {
        final FoundNode reached = new FoundNode(this, parent, parentBranch);
        return reached;
    }

    /**
     * Processes the instance for learning.
     *
     * @param inst the processed instance
     * @param proc the model aggregator processor where this learning node exists
     */
    abstract void learnFromInstance(Instance inst, ModelAggregatorProcessor proc);
}
http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/787864b6/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/trees/LocalResultContentEvent.java ---------------------------------------------------------------------- diff --git
a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/trees/LocalResultContentEvent.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/trees/LocalResultContentEvent.java new file mode 100644 index 0000000..142d28a --- /dev/null +++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/trees/LocalResultContentEvent.java @@ -0,0 +1,92 @@ +package com.yahoo.labs.samoa.learners.classifiers.trees; + +/* + * #%L + * SAMOA + * %% + * Copyright (C) 2013 Yahoo! Inc. + * %% + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * #L% + */ + +import com.yahoo.labs.samoa.moa.classifiers.core.AttributeSplitSuggestion; +import com.yahoo.labs.samoa.core.ContentEvent; + +/** + * Local Result Content Event is the content event that represents local + * calculation of statistic in Local Statistic Processor. 
+ * + * @author Arinto Murdopo + * + */
final class LocalResultContentEvent implements ContentEvent {

    private static final long serialVersionUID = -4206620993777418571L;

    private final AttributeSplitSuggestion bestSuggestion;
    private final AttributeSplitSuggestion secondBestSuggestion;
    private final long splitId;

    /**
     * Creates an empty event: no suggestions and a split id of {@code -1}.
     */
    public LocalResultContentEvent() {
        this(-1, null, null);
    }

    /**
     * @param splitId    the split id this local result belongs to
     * @param best       the best attribute split suggestion
     * @param secondBest the second best attribute split suggestion
     */
    LocalResultContentEvent(long splitId, AttributeSplitSuggestion best,
            AttributeSplitSuggestion secondBest) {
        this.bestSuggestion = best;
        this.secondBestSuggestion = secondBest;
        this.splitId = splitId;
    }

    /**
     * Returns the best attribute split suggestion from this local statistic
     * calculation.
     *
     * @return the best attribute split suggestion
     */
    AttributeSplitSuggestion getBestSuggestion() {
        return bestSuggestion;
    }

    /**
     * Returns the second best attribute split suggestion from this local
     * statistic calculation.
     *
     * @return the second best attribute split suggestion
     */
    AttributeSplitSuggestion getSecondBestSuggestion() {
        return secondBestSuggestion;
    }

    /**
     * Returns the split id of this local statistic calculation result.
     *
     * @return the split id of this local calculation result
     */
    long getSplitId() {
        return splitId;
    }

    /** This event carries no key; always returns {@code null}. */
    @Override
    public String getKey() {
        return null;
    }

    /** Intentionally a no-op: the key is not stored. */
    @Override
    public void setKey(String str) {
        // do nothing
    }

    /** Always {@code false}. */
    @Override
    public boolean isLastEvent() {
        return false;
    }
}
http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/787864b6/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/trees/LocalStatisticsProcessor.java ---------------------------------------------------------------------- diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/trees/LocalStatisticsProcessor.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/trees/LocalStatisticsProcessor.java new file mode 100644 index 0000000..25e5592 --- /dev/null +++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/trees/LocalStatisticsProcessor.java @@ -0,0 +1,246 @@ +package com.yahoo.labs.samoa.learners.classifiers.trees; + +/* + * #%L + * SAMOA + * %% + * Copyright (C) 2013 Yahoo! Inc. + * %% + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ * #L% + */ + +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; +import java.util.Vector; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import com.yahoo.labs.samoa.moa.classifiers.core.AttributeSplitSuggestion; +import com.yahoo.labs.samoa.moa.classifiers.core.attributeclassobservers.AttributeClassObserver; +import com.yahoo.labs.samoa.moa.classifiers.core.attributeclassobservers.GaussianNumericAttributeClassObserver; +import com.yahoo.labs.samoa.moa.classifiers.core.attributeclassobservers.NominalAttributeClassObserver; +import com.yahoo.labs.samoa.moa.classifiers.core.splitcriteria.InfoGainSplitCriterion; +import com.yahoo.labs.samoa.moa.classifiers.core.splitcriteria.SplitCriterion; + +import com.google.common.collect.HashBasedTable; +import com.google.common.collect.Table; + +import com.yahoo.labs.samoa.core.ContentEvent; +import com.yahoo.labs.samoa.core.Processor; +import com.yahoo.labs.samoa.topology.Stream; + +/** + * Local Statistic Processor contains the local statistic of a subset of the attributes. 
+ * @author Arinto Murdopo + * + */
final class LocalStatisticsProcessor implements Processor {

    private static final long serialVersionUID = -3967695130634517631L;
    private static final Logger logger = LoggerFactory.getLogger(LocalStatisticsProcessor.class);

    // Attribute observers, one per (learning node id, observer index) cell.
    // Created in onCreate(), so it is null until the processor is initialised.
    private Table<Long, Integer, AttributeClassObserver> localStats;

    // Stream on which LocalResultContentEvents are emitted.
    private Stream computationResultStream;

    private final SplitCriterion splitCriterion;
    private final boolean binarySplit;

    // Prototype observers: never fed data directly, only copied whenever a
    // new (node, attribute) cell needs an observer.
    private final AttributeClassObserver nominalClassObserver;
    private final AttributeClassObserver numericClassObserver;

    private LocalStatisticsProcessor(Builder builder) {
        this.splitCriterion = builder.splitCriterion;
        this.binarySplit = builder.binarySplit;
        this.nominalClassObserver = builder.nominalClassObserver;
        this.numericClassObserver = builder.numericClassObserver;
    }

    /**
     * Dispatches on the event type:
     * <ul>
     * <li>{@code AttributeBatchContentEvent}: updates the local statistics;</li>
     * <li>{@code ComputeContentEvent}: computes the best split suggestions and
     * sends them on the computation result stream;</li>
     * <li>{@code DeleteContentEvent}: drops the statistics of a learning node.</li>
     * </ul>
     * Any other event type is ignored.
     *
     * @param event the incoming content event
     * @return always {@code false}
     */
    @Override
    public boolean process(ContentEvent event) {
        if (event instanceof AttributeBatchContentEvent) {
            updateLocalStatistics((AttributeBatchContentEvent) event);
        } else if (event instanceof ComputeContentEvent) {
            computeAndSendSuggestions((ComputeContentEvent) event);
        } else if (event instanceof DeleteContentEvent) {
            dropLocalStatistics((DeleteContentEvent) event);
        }
        return false;
    }

    /** Feeds every attribute event of the batch into its observer, creating the observer on first use. */
    private void updateLocalStatistics(AttributeBatchContentEvent batch) {
        List<ContentEvent> contentEventList = batch.getContentEventList();
        for (ContentEvent contentEvent : contentEventList) {
            AttributeContentEvent ace = (AttributeContentEvent) contentEvent;
            Long learningNodeId = ace.getLearningNodeId();
            Integer obsIndex = ace.getObsIndex();

            AttributeClassObserver obs = localStats.get(learningNodeId, obsIndex);
            if (obs == null) {
                obs = ace.isNominal() ? newNominalClassObserver()
                        : newNumericClassObserver();
                localStats.put(learningNodeId, obsIndex, obs);
            }
            obs.observeAttributeClass(ace.getAttrVal(), ace.getClassVal(),
                    ace.getWeight());
        }
    }

    /** Evaluates all observers of the requested learning node and emits the two top-ranked suggestions. */
    private void computeAndSendSuggestions(ComputeContentEvent cce) {
        Long learningNodeId = cce.getLearningNodeId();
        double[] preSplitDist = cce.getPreSplitDist();

        Map<Integer, AttributeClassObserver> learningNodeRowMap =
                localStats.row(learningNodeId);
        List<AttributeSplitSuggestion> suggestions = new Vector<>();

        for (Entry<Integer, AttributeClassObserver> entry : learningNodeRowMap.entrySet()) {
            AttributeSplitSuggestion suggestion = entry.getValue()
                    .getBestEvaluatedSplitSuggestion(splitCriterion,
                            preSplitDist, entry.getKey(), binarySplit);
            if (suggestion != null) {
                suggestions.add(suggestion);
            }
        }

        AttributeSplitSuggestion[] bestSuggestions = suggestions
                .toArray(new AttributeSplitSuggestion[suggestions.size()]);

        // Natural-order sort; the last element is treated as the best one
        // (presumably ascending by merit - confirm against
        // AttributeSplitSuggestion.compareTo).
        Arrays.sort(bestSuggestions);

        AttributeSplitSuggestion bestSuggestion = null;
        AttributeSplitSuggestion secondBestSuggestion = null;
        if (bestSuggestions.length >= 1) {
            bestSuggestion = bestSuggestions[bestSuggestions.length - 1];
            if (bestSuggestions.length >= 2) {
                secondBestSuggestion = bestSuggestions[bestSuggestions.length - 2];
            }
        }

        // A result is sent even when there are no suggestions (both null).
        LocalResultContentEvent lcre =
                new LocalResultContentEvent(cce.getSplitId(), bestSuggestion, secondBestSuggestion);
        computationResultStream.put(lcre);
        logger.debug("Finish compute event");
    }

    /** Removes every observer stored for the learning node named by the event. */
    private void dropLocalStatistics(DeleteContentEvent dce) {
        Long learningNodeId = dce.getLearningNodeId();
        localStats.rowMap().remove(learningNodeId);
    }

    /** Creates the empty statistics table; the id is not used. */
    @Override
    public void onCreate(int id) {
        this.localStats = HashBasedTable.create();
    }

    /**
     * Creates a new processor with the same configuration and computation
     * result stream as {@code p}. Collected statistics are not copied.
     *
     * @param p the processor to replicate; must be a {@code LocalStatisticsProcessor}
     * @return the newly built processor
     */
    @Override
    public Processor newProcessor(Processor p) {
        LocalStatisticsProcessor oldProcessor = (LocalStatisticsProcessor) p;
        LocalStatisticsProcessor newProcessor =
                new LocalStatisticsProcessor.Builder(oldProcessor).build();

        newProcessor.setComputationResultStream(oldProcessor.computationResultStream);

        return newProcessor;
    }

    /**
     * Sets the computation result stream when using this processor to build
     * a topology.
     *
     * @param computeStream the stream that receives the local results
     */
    void setComputationResultStream(Stream computeStream) {
        this.computationResultStream = computeStream;
    }

    private AttributeClassObserver newNominalClassObserver() {
        return (AttributeClassObserver) this.nominalClassObserver.copy();
    }

    private AttributeClassObserver newNumericClassObserver() {
        return (AttributeClassObserver) this.numericClassObserver.copy();
    }

    /**
     * Builder class to replace constructors with many parameters. Defaults:
     * information-gain split criterion, non-binary splits, a
     * {@code NominalAttributeClassObserver} and a
     * {@code GaussianNumericAttributeClassObserver}.
     *
     * @author Arinto Murdopo
     */
    static class Builder {

        private SplitCriterion splitCriterion = new InfoGainSplitCriterion();
        private boolean binarySplit = false;
        private AttributeClassObserver nominalClassObserver = new NominalAttributeClassObserver();
        private AttributeClassObserver numericClassObserver = new GaussianNumericAttributeClassObserver();

        Builder() {
        }

        /**
         * Starts from the full configuration of an existing processor.
         *
         * @param oldProcessor the processor whose settings are copied
         */
        Builder(LocalStatisticsProcessor oldProcessor) {
            this.splitCriterion = oldProcessor.splitCriterion;
            this.binarySplit = oldProcessor.binarySplit;
            // The observer prototypes must be carried over as well; otherwise
            // processors created through newProcessor() silently fall back to
            // the default observers instead of the configured ones.
            this.nominalClassObserver = oldProcessor.nominalClassObserver;
            this.numericClassObserver = oldProcessor.numericClassObserver;
        }

        Builder splitCriterion(SplitCriterion splitCriterion) {
            this.splitCriterion = splitCriterion;
            return this;
        }

        Builder binarySplit(boolean binarySplit) {
            this.binarySplit = binarySplit;
            return this;
        }

        Builder nominalClassObserver(AttributeClassObserver nominalClassObserver) {
            this.nominalClassObserver = nominalClassObserver;
            return this;
        }

        Builder numericClassObserver(AttributeClassObserver numericClassObserver) {
            this.numericClassObserver = numericClassObserver;
            return this;
        }

        LocalStatisticsProcessor build() {
            return new LocalStatisticsProcessor(this);
        }
    }

}
