http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/23a35dbe/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/distributed/AMRulesAggregatorProcessor.java ---------------------------------------------------------------------- diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/distributed/AMRulesAggregatorProcessor.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/distributed/AMRulesAggregatorProcessor.java index debe912..5cb4f01 100644 --- a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/distributed/AMRulesAggregatorProcessor.java +++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/distributed/AMRulesAggregatorProcessor.java @@ -48,478 +48,487 @@ import com.yahoo.labs.samoa.topology.Stream; * Model Aggregator Processor (VAMR). * * @author Anh Thu Vu - * + * */ public class AMRulesAggregatorProcessor implements Processor { - /** + /** * */ - private static final long serialVersionUID = 6303385725332704251L; - - private static final Logger logger = - LoggerFactory.getLogger(AMRulesAggregatorProcessor.class); - - private int processorId; - - // Rules & default rule - protected transient List<PassiveRule> ruleSet; - protected transient ActiveRule defaultRule; - protected transient int ruleNumberID; - protected transient double[] statistics; - - // SAMOA Stream - private Stream statisticsStream; - private Stream resultStream; - - // Options - protected int pageHinckleyThreshold; - protected double pageHinckleyAlpha; - protected boolean driftDetection; - protected int predictionFunction; // Adaptive=0 Perceptron=1 TargetMean=2 - protected boolean constantLearningRatioDecay; - protected double learningRatio; - - protected double splitConfidence; - protected double tieThreshold; - protected int gracePeriod; - - protected boolean noAnomalyDetection; - protected double multivariateAnomalyProbabilityThreshold; - protected double univariateAnomalyprobabilityThreshold; - protected int anomalyNumInstThreshold; - - protected boolean unorderedRules; - - protected FIMTDDNumericAttributeClassLimitObserver numericObserver; - protected int voteType; - - /* - * Constructor - */ - public AMRulesAggregatorProcessor (Builder builder) { - this.pageHinckleyThreshold = builder.pageHinckleyThreshold; - this.pageHinckleyAlpha = builder.pageHinckleyAlpha; - this.driftDetection = builder.driftDetection; - this.predictionFunction = builder.predictionFunction; - this.constantLearningRatioDecay = builder.constantLearningRatioDecay; - this.learningRatio = builder.learningRatio; - this.splitConfidence = builder.splitConfidence; - this.tieThreshold = builder.tieThreshold; - this.gracePeriod = builder.gracePeriod; - - this.noAnomalyDetection = builder.noAnomalyDetection; - this.multivariateAnomalyProbabilityThreshold = builder.multivariateAnomalyProbabilityThreshold; - this.univariateAnomalyprobabilityThreshold = builder.univariateAnomalyprobabilityThreshold; - this.anomalyNumInstThreshold = builder.anomalyNumInstThreshold; - this.unorderedRules = builder.unorderedRules; - - this.numericObserver = builder.numericObserver; - this.voteType = builder.voteType; - } - - /* - * Process - */ - @Override - public boolean process(ContentEvent event) { - if (event instanceof InstanceContentEvent) { - InstanceContentEvent instanceEvent = (InstanceContentEvent) event; - this.processInstanceEvent(instanceEvent); - } - else if (event instanceof PredicateContentEvent) { - this.updateRuleSplitNode((PredicateContentEvent) event); - } - else if (event instanceof RuleContentEvent) { - RuleContentEvent rce = (RuleContentEvent) event; - if (rce.isRemoving()) { - this.removeRule(rce.getRuleNumberID()); - } - } - - return true; - } - - // Merge predict and train so we only check for covering rules one time - private void processInstanceEvent(InstanceContentEvent instanceEvent) { - Instance instance = instanceEvent.getInstance(); - boolean predictionCovered = false; - boolean trainingCovered = false; - boolean continuePrediction = instanceEvent.isTesting(); - boolean continueTraining = instanceEvent.isTraining(); - - ErrorWeightedVote errorWeightedVote = newErrorWeightedVote(); - Iterator<PassiveRule> ruleIterator= this.ruleSet.iterator(); - while (ruleIterator.hasNext()) { - if (!continuePrediction && !continueTraining) - break; - - PassiveRule rule = ruleIterator.next(); - - if (rule.isCovering(instance) == true){ - predictionCovered = true; - - if (continuePrediction) { - double [] vote=rule.getPrediction(instance); - double error= rule.getCurrentError(); - errorWeightedVote.addVote(vote,error); - if (!this.unorderedRules) continuePrediction = false; - } - - if (continueTraining) { - if (!isAnomaly(instance, rule)) { - trainingCovered = true; - rule.updateStatistics(instance); - // Send instance to statistics PIs - sendInstanceToRule(instance, rule.getRuleNumberID()); - - if (!this.unorderedRules) continueTraining = false; - } - } - } - } - - if (predictionCovered) { - // Combined prediction - ResultContentEvent rce = newResultContentEvent(errorWeightedVote.computeWeightedVote(), instanceEvent); - resultStream.put(rce); - } - else if (instanceEvent.isTesting()) { - // predict with default rule - double [] vote=defaultRule.getPrediction(instance); - ResultContentEvent rce = newResultContentEvent(vote, instanceEvent); - resultStream.put(rce); - } - - if (!trainingCovered && instanceEvent.isTraining()) { - // train default rule with this instance - defaultRule.updateStatistics(instance); - if (defaultRule.getInstancesSeen() % this.gracePeriod == 0.0) { - if (defaultRule.tryToExpand(this.splitConfidence, this.tieThreshold) == true) { - ActiveRule newDefaultRule=newRule(defaultRule.getRuleNumberID(),(RuleActiveRegressionNode)defaultRule.getLearningNode(), - ((RuleActiveRegressionNode)defaultRule.getLearningNode()).getStatisticsOtherBranchSplit()); //other branch - defaultRule.split(); - defaultRule.setRuleNumberID(++ruleNumberID); - this.ruleSet.add(new PassiveRule(this.defaultRule)); - // send to statistics PI - sendAddRuleEvent(defaultRule.getRuleNumberID(), this.defaultRule); - defaultRule=newDefaultRule; - } - } - } - } - - /** - * Helper method to generate new ResultContentEvent based on an instance and - * its prediction result. - * @param prediction The predicted class label from the decision tree model. - * @param inEvent The associated instance content event - * @return ResultContentEvent to be sent into Evaluator PI or other destination PI. - */ - private ResultContentEvent newResultContentEvent(double[] prediction, InstanceContentEvent inEvent){ - ResultContentEvent rce = new ResultContentEvent(inEvent.getInstanceIndex(), inEvent.getInstance(), inEvent.getClassId(), prediction, inEvent.isLastEvent()); - rce.setClassifierIndex(this.processorId); - rce.setEvaluationIndex(inEvent.getEvaluationIndex()); - return rce; - } - - public ErrorWeightedVote newErrorWeightedVote() { - if (voteType == 1) - return new UniformWeightedVote(); - return new InverseErrorWeightedVote(); - } - - /** - * Method to verify if the instance is an anomaly. - * @param instance - * @param rule - * @return - */ - private boolean isAnomaly(Instance instance, LearningRule rule) { - //AMRUles is equipped with anomaly detection. If on, compute the anomaly value. - boolean isAnomaly = false; - if (this.noAnomalyDetection == false){ - if (rule.getInstancesSeen() >= this.anomalyNumInstThreshold) { - isAnomaly = rule.isAnomaly(instance, - this.univariateAnomalyprobabilityThreshold, - this.multivariateAnomalyProbabilityThreshold, - this.anomalyNumInstThreshold); - } - } - return isAnomaly; - } - - /* - * Create new rules - */ - private ActiveRule newRule(int ID, RuleActiveRegressionNode node, double[] statistics) { - ActiveRule r=newRule(ID); - - if (node!=null) - { - if(node.getPerceptron()!=null) - { - r.getLearningNode().setPerceptron(new Perceptron(node.getPerceptron())); - r.getLearningNode().getPerceptron().setLearningRatio(this.learningRatio); - } - if (statistics==null) - { - double mean; - if(node.getNodeStatistics().getValue(0)>0){ - mean=node.getNodeStatistics().getValue(1)/node.getNodeStatistics().getValue(0); - r.getLearningNode().getTargetMean().reset(mean, 1); - } - } - } - if (statistics!=null && ((RuleActiveRegressionNode)r.getLearningNode()).getTargetMean()!=null) - { - double mean; - if(statistics[0]>0){ - mean=statistics[1]/statistics[0]; - ((RuleActiveRegressionNode)r.getLearningNode()).getTargetMean().reset(mean, (long)statistics[0]); - } - } - return r; - } - - private ActiveRule newRule(int ID) { - ActiveRule r=new ActiveRule.Builder(). - threshold(this.pageHinckleyThreshold). - alpha(this.pageHinckleyAlpha). - changeDetection(this.driftDetection). - predictionFunction(this.predictionFunction). - statistics(new double[3]). - learningRatio(this.learningRatio). - numericObserver(numericObserver). - id(ID).build(); - return r; - } - - /* - * Add predicate/RuleSplitNode for a rule - */ - private void updateRuleSplitNode(PredicateContentEvent pce) { - int ruleID = pce.getRuleNumberID(); - for (PassiveRule rule:ruleSet) { - if (rule.getRuleNumberID() == ruleID) { - if (pce.getRuleSplitNode() != null) - rule.nodeListAdd(pce.getRuleSplitNode()); - if (pce.getLearningNode() != null) - rule.setLearningNode(pce.getLearningNode()); - } - } - } - - /* - * Remove rule - */ - private void removeRule(int ruleID) { - for (PassiveRule rule:ruleSet) { - if (rule.getRuleNumberID() == ruleID) { - ruleSet.remove(rule); - break; - } - } - } - - @Override - public void onCreate(int id) { - this.processorId = id; - this.statistics= new double[]{0.0,0,0}; - this.ruleNumberID=0; - this.defaultRule = newRule(++this.ruleNumberID); - - this.ruleSet = new LinkedList<PassiveRule>(); - } - - /* - * Clone processor - */ - @Override - public Processor newProcessor(Processor p) { - AMRulesAggregatorProcessor oldProcessor = (AMRulesAggregatorProcessor) p; - Builder builder = new Builder(oldProcessor); - AMRulesAggregatorProcessor newProcessor = builder.build(); - newProcessor.resultStream = oldProcessor.resultStream; - newProcessor.statisticsStream = oldProcessor.statisticsStream; - return newProcessor; - } - - /* - * Send events - */ - private void sendInstanceToRule(Instance instance, int ruleID) { - AssignmentContentEvent ace = new AssignmentContentEvent(ruleID, instance); - this.statisticsStream.put(ace); - } - - - - private void sendAddRuleEvent(int ruleID, ActiveRule rule) { - RuleContentEvent rce = new RuleContentEvent(ruleID, rule, false); - this.statisticsStream.put(rce); - } - - /* - * Output streams - */ - public void setStatisticsStream(Stream statisticsStream) { - this.statisticsStream = statisticsStream; - } - - public Stream getStatisticsStream() { - return this.statisticsStream; - } - - public void setResultStream(Stream resultStream) { - this.resultStream = resultStream; - } - - public Stream getResultStream() { - return this.resultStream; - } - - /* - * Others - */ - public boolean isRandomizable() { - return true; + private static final long serialVersionUID = 6303385725332704251L; + + private static final Logger logger = + LoggerFactory.getLogger(AMRulesAggregatorProcessor.class); + + private int processorId; + + // Rules & default rule + protected transient List<PassiveRule> ruleSet; + protected transient ActiveRule defaultRule; + protected transient int ruleNumberID; + protected transient double[] statistics; + + // SAMOA Stream + private Stream statisticsStream; + private Stream resultStream; + + // Options + protected int pageHinckleyThreshold; + protected double pageHinckleyAlpha; + protected boolean driftDetection; + protected int predictionFunction; // Adaptive=0 Perceptron=1 TargetMean=2 + protected boolean constantLearningRatioDecay; + protected double learningRatio; + + protected double splitConfidence; + protected double tieThreshold; + protected int gracePeriod; + + protected boolean noAnomalyDetection; + protected double multivariateAnomalyProbabilityThreshold; + protected double univariateAnomalyprobabilityThreshold; + protected int anomalyNumInstThreshold; + + protected boolean unorderedRules; + + protected FIMTDDNumericAttributeClassLimitObserver numericObserver; + protected int voteType; + + /* + * Constructor + */ + public AMRulesAggregatorProcessor(Builder builder) { + this.pageHinckleyThreshold = builder.pageHinckleyThreshold; + this.pageHinckleyAlpha = builder.pageHinckleyAlpha; + this.driftDetection = builder.driftDetection; + this.predictionFunction = builder.predictionFunction; + this.constantLearningRatioDecay = builder.constantLearningRatioDecay; + this.learningRatio = builder.learningRatio; + this.splitConfidence = builder.splitConfidence; + this.tieThreshold = builder.tieThreshold; + this.gracePeriod = builder.gracePeriod; + + this.noAnomalyDetection = builder.noAnomalyDetection; + this.multivariateAnomalyProbabilityThreshold = builder.multivariateAnomalyProbabilityThreshold; + this.univariateAnomalyprobabilityThreshold = builder.univariateAnomalyprobabilityThreshold; + this.anomalyNumInstThreshold = builder.anomalyNumInstThreshold; + this.unorderedRules = builder.unorderedRules; + + this.numericObserver = builder.numericObserver; + this.voteType = builder.voteType; + } + + /* + * Process + */ + @Override + public boolean process(ContentEvent event) { + if (event instanceof InstanceContentEvent) { + InstanceContentEvent instanceEvent = (InstanceContentEvent) event; + this.processInstanceEvent(instanceEvent); } - - /* - * Builder - */ - public static class Builder { - private int pageHinckleyThreshold; - private double pageHinckleyAlpha; - private boolean driftDetection; - private int predictionFunction; // Adaptive=0 Perceptron=1 TargetMean=2 - private boolean constantLearningRatioDecay; - private double learningRatio; - private double splitConfidence; - private double tieThreshold; - private int gracePeriod; - - private boolean noAnomalyDetection; - private double multivariateAnomalyProbabilityThreshold; - private double univariateAnomalyprobabilityThreshold; - private int anomalyNumInstThreshold; - - private boolean unorderedRules; - - private FIMTDDNumericAttributeClassLimitObserver numericObserver; - private int voteType; - - private Instances dataset; - - public Builder(Instances dataset){ - this.dataset = dataset; - } - - public Builder(AMRulesAggregatorProcessor processor) { - this.pageHinckleyThreshold = processor.pageHinckleyThreshold; - this.pageHinckleyAlpha = processor.pageHinckleyAlpha; - this.driftDetection = processor.driftDetection; - this.predictionFunction = processor.predictionFunction; - this.constantLearningRatioDecay = processor.constantLearningRatioDecay; - this.learningRatio = processor.learningRatio; - this.splitConfidence = processor.splitConfidence; - this.tieThreshold = processor.tieThreshold; - this.gracePeriod = processor.gracePeriod; - - this.noAnomalyDetection = processor.noAnomalyDetection; - this.multivariateAnomalyProbabilityThreshold = processor.multivariateAnomalyProbabilityThreshold; - this.univariateAnomalyprobabilityThreshold = processor.univariateAnomalyprobabilityThreshold; - this.anomalyNumInstThreshold = processor.anomalyNumInstThreshold; - this.unorderedRules = processor.unorderedRules; - - this.numericObserver = processor.numericObserver; - this.voteType = processor.voteType; - } - - public Builder threshold(int threshold) { - this.pageHinckleyThreshold = threshold; - return this; - } - - public Builder alpha(double alpha) { - this.pageHinckleyAlpha = alpha; - return this; - } - - public Builder changeDetection(boolean changeDetection) { - this.driftDetection = changeDetection; - return this; - } - - public Builder predictionFunction(int predictionFunction) { - this.predictionFunction = predictionFunction; - return this; - } - - public Builder constantLearningRatioDecay(boolean constantDecay) { - this.constantLearningRatioDecay = constantDecay; - return this; - } - - public Builder learningRatio(double learningRatio) { - this.learningRatio = learningRatio; - return this; - } - - public Builder splitConfidence(double splitConfidence) { - this.splitConfidence = splitConfidence; - return this; - } - - public Builder tieThreshold(double tieThreshold) { - this.tieThreshold = tieThreshold; - return this; - } - - public Builder gracePeriod(int gracePeriod) { - this.gracePeriod = gracePeriod; - return this; - } - - public Builder noAnomalyDetection(boolean noAnomalyDetection) { - this.noAnomalyDetection = noAnomalyDetection; - return this; - } - - public Builder multivariateAnomalyProbabilityThreshold(double mAnomalyThreshold) { - this.multivariateAnomalyProbabilityThreshold = mAnomalyThreshold; - return this; - } - - public Builder univariateAnomalyProbabilityThreshold(double uAnomalyThreshold) { - this.univariateAnomalyprobabilityThreshold = uAnomalyThreshold; - return this; - } - - public Builder anomalyNumberOfInstancesThreshold(int anomalyNumInstThreshold) { - this.anomalyNumInstThreshold = anomalyNumInstThreshold; - return this; - } - - public Builder unorderedRules(boolean unorderedRules) { - this.unorderedRules = unorderedRules; - return this; - } - - public Builder numericObserver(FIMTDDNumericAttributeClassLimitObserver numericObserver) { - this.numericObserver = numericObserver; - return this; - } - - public Builder voteType(int voteType) { - this.voteType = voteType; - return this; - } - - public AMRulesAggregatorProcessor build() { - return new AMRulesAggregatorProcessor(this); - } - } + else if (event instanceof PredicateContentEvent) { + this.updateRuleSplitNode((PredicateContentEvent) event); + } + else if (event instanceof RuleContentEvent) { + RuleContentEvent rce = (RuleContentEvent) event; + if (rce.isRemoving()) { + this.removeRule(rce.getRuleNumberID()); + } + } + + return true; + } + + // Merge predict and train so we only check for covering rules one time + private void processInstanceEvent(InstanceContentEvent instanceEvent) { + Instance instance = instanceEvent.getInstance(); + boolean predictionCovered = false; + boolean trainingCovered = false; + boolean continuePrediction = instanceEvent.isTesting(); + boolean continueTraining = instanceEvent.isTraining(); + + ErrorWeightedVote errorWeightedVote = newErrorWeightedVote(); + Iterator<PassiveRule> ruleIterator = this.ruleSet.iterator(); + while (ruleIterator.hasNext()) { + if (!continuePrediction && !continueTraining) + break; + + PassiveRule rule = ruleIterator.next(); + + if (rule.isCovering(instance) == true) { + predictionCovered = true; + + if (continuePrediction) { + double[] vote = rule.getPrediction(instance); + double error = rule.getCurrentError(); + errorWeightedVote.addVote(vote, error); + if (!this.unorderedRules) + continuePrediction = false; + } + + if (continueTraining) { + if (!isAnomaly(instance, rule)) { + trainingCovered = true; + rule.updateStatistics(instance); + // Send instance to statistics PIs + sendInstanceToRule(instance, rule.getRuleNumberID()); + + if (!this.unorderedRules) + continueTraining = false; + } + } + } + } + + if (predictionCovered) { + // Combined prediction + ResultContentEvent rce = newResultContentEvent(errorWeightedVote.computeWeightedVote(), instanceEvent); + resultStream.put(rce); + } + else if (instanceEvent.isTesting()) { + // predict with default rule + double[] vote = defaultRule.getPrediction(instance); + ResultContentEvent rce = newResultContentEvent(vote, instanceEvent); + resultStream.put(rce); + } + + if (!trainingCovered && instanceEvent.isTraining()) { + // train default rule with this instance + defaultRule.updateStatistics(instance); + if (defaultRule.getInstancesSeen() % this.gracePeriod == 0.0) { + if (defaultRule.tryToExpand(this.splitConfidence, this.tieThreshold) == true) { + ActiveRule newDefaultRule = newRule(defaultRule.getRuleNumberID(), + (RuleActiveRegressionNode) defaultRule.getLearningNode(), + ((RuleActiveRegressionNode) defaultRule.getLearningNode()).getStatisticsOtherBranchSplit()); // other + // branch + defaultRule.split(); + defaultRule.setRuleNumberID(++ruleNumberID); + this.ruleSet.add(new PassiveRule(this.defaultRule)); + // send to statistics PI + sendAddRuleEvent(defaultRule.getRuleNumberID(), this.defaultRule); + defaultRule = newDefaultRule; + } + } + } + } + + /** + * Helper method to generate new ResultContentEvent based on an instance and + * its prediction result. + * + * @param prediction + * The predicted class label from the decision tree model. + * @param inEvent + * The associated instance content event + * @return ResultContentEvent to be sent into Evaluator PI or other + * destination PI. + */ + private ResultContentEvent newResultContentEvent(double[] prediction, InstanceContentEvent inEvent) { + ResultContentEvent rce = new ResultContentEvent(inEvent.getInstanceIndex(), inEvent.getInstance(), + inEvent.getClassId(), prediction, inEvent.isLastEvent()); + rce.setClassifierIndex(this.processorId); + rce.setEvaluationIndex(inEvent.getEvaluationIndex()); + return rce; + } + + public ErrorWeightedVote newErrorWeightedVote() { + if (voteType == 1) + return new UniformWeightedVote(); + return new InverseErrorWeightedVote(); + } + + /** + * Method to verify if the instance is an anomaly. + * + * @param instance + * @param rule + * @return + */ + private boolean isAnomaly(Instance instance, LearningRule rule) { + // AMRUles is equipped with anomaly detection. If on, compute the anomaly + // value. + boolean isAnomaly = false; + if (this.noAnomalyDetection == false) { + if (rule.getInstancesSeen() >= this.anomalyNumInstThreshold) { + isAnomaly = rule.isAnomaly(instance, + this.univariateAnomalyprobabilityThreshold, + this.multivariateAnomalyProbabilityThreshold, + this.anomalyNumInstThreshold); + } + } + return isAnomaly; + } + + /* + * Create new rules + */ + private ActiveRule newRule(int ID, RuleActiveRegressionNode node, double[] statistics) { + ActiveRule r = newRule(ID); + + if (node != null) + { + if (node.getPerceptron() != null) + { + r.getLearningNode().setPerceptron(new Perceptron(node.getPerceptron())); + r.getLearningNode().getPerceptron().setLearningRatio(this.learningRatio); + } + if (statistics == null) + { + double mean; + if (node.getNodeStatistics().getValue(0) > 0) { + mean = node.getNodeStatistics().getValue(1) / node.getNodeStatistics().getValue(0); + r.getLearningNode().getTargetMean().reset(mean, 1); + } + } + } + if (statistics != null && ((RuleActiveRegressionNode) r.getLearningNode()).getTargetMean() != null) + { + double mean; + if (statistics[0] > 0) { + mean = statistics[1] / statistics[0]; + ((RuleActiveRegressionNode) r.getLearningNode()).getTargetMean().reset(mean, (long) statistics[0]); + } + } + return r; + } + + private ActiveRule newRule(int ID) { + ActiveRule r = new ActiveRule.Builder(). + threshold(this.pageHinckleyThreshold). + alpha(this.pageHinckleyAlpha). + changeDetection(this.driftDetection). + predictionFunction(this.predictionFunction). + statistics(new double[3]). + learningRatio(this.learningRatio). + numericObserver(numericObserver). + id(ID).build(); + return r; + } + + /* + * Add predicate/RuleSplitNode for a rule + */ + private void updateRuleSplitNode(PredicateContentEvent pce) { + int ruleID = pce.getRuleNumberID(); + for (PassiveRule rule : ruleSet) { + if (rule.getRuleNumberID() == ruleID) { + if (pce.getRuleSplitNode() != null) + rule.nodeListAdd(pce.getRuleSplitNode()); + if (pce.getLearningNode() != null) + rule.setLearningNode(pce.getLearningNode()); + } + } + } + + /* + * Remove rule + */ + private void removeRule(int ruleID) { + for (PassiveRule rule : ruleSet) { + if (rule.getRuleNumberID() == ruleID) { + ruleSet.remove(rule); + break; + } + } + } + + @Override + public void onCreate(int id) { + this.processorId = id; + this.statistics = new double[] { 0.0, 0, 0 }; + this.ruleNumberID = 0; + this.defaultRule = newRule(++this.ruleNumberID); + + this.ruleSet = new LinkedList<PassiveRule>(); + } + + /* + * Clone processor + */ + @Override + public Processor newProcessor(Processor p) { + AMRulesAggregatorProcessor oldProcessor = (AMRulesAggregatorProcessor) p; + Builder builder = new Builder(oldProcessor); + AMRulesAggregatorProcessor newProcessor = builder.build(); + newProcessor.resultStream = oldProcessor.resultStream; + newProcessor.statisticsStream = oldProcessor.statisticsStream; + return newProcessor; + } + + /* + * Send events + */ + private void sendInstanceToRule(Instance instance, int ruleID) { + AssignmentContentEvent ace = new AssignmentContentEvent(ruleID, instance); + this.statisticsStream.put(ace); + } + + private void sendAddRuleEvent(int ruleID, ActiveRule rule) { + RuleContentEvent rce = new RuleContentEvent(ruleID, rule, false); + this.statisticsStream.put(rce); + } + + /* + * Output streams + */ + public void setStatisticsStream(Stream statisticsStream) { + this.statisticsStream = statisticsStream; + } + + public Stream getStatisticsStream() { + return this.statisticsStream; + } + + public void setResultStream(Stream resultStream) { + this.resultStream = resultStream; + } + + public Stream getResultStream() { + return this.resultStream; + } + + /* + * Others + */ + public boolean isRandomizable() { + return true; + } + + /* + * Builder + */ + public static class Builder { + private int pageHinckleyThreshold; + private double pageHinckleyAlpha; + private boolean driftDetection; + private int predictionFunction; // Adaptive=0 Perceptron=1 TargetMean=2 + private boolean constantLearningRatioDecay; + private double learningRatio; + private double splitConfidence; + private double tieThreshold; + private int gracePeriod; + + private boolean noAnomalyDetection; + private double multivariateAnomalyProbabilityThreshold; + private double univariateAnomalyprobabilityThreshold; + private int anomalyNumInstThreshold; + + private boolean unorderedRules; + + private FIMTDDNumericAttributeClassLimitObserver numericObserver; + private int voteType; + + private Instances dataset; + + public Builder(Instances dataset) { + this.dataset = dataset; + } + + public Builder(AMRulesAggregatorProcessor processor) { + this.pageHinckleyThreshold = processor.pageHinckleyThreshold; + this.pageHinckleyAlpha = processor.pageHinckleyAlpha; + this.driftDetection = processor.driftDetection; + this.predictionFunction = processor.predictionFunction; + this.constantLearningRatioDecay = processor.constantLearningRatioDecay; + this.learningRatio = processor.learningRatio; + this.splitConfidence = processor.splitConfidence; + this.tieThreshold = processor.tieThreshold; + this.gracePeriod = processor.gracePeriod; + + this.noAnomalyDetection = processor.noAnomalyDetection; + this.multivariateAnomalyProbabilityThreshold = processor.multivariateAnomalyProbabilityThreshold; + this.univariateAnomalyprobabilityThreshold = processor.univariateAnomalyprobabilityThreshold; + this.anomalyNumInstThreshold = processor.anomalyNumInstThreshold; + this.unorderedRules = processor.unorderedRules; + + this.numericObserver = processor.numericObserver; + this.voteType = processor.voteType; + } + + public Builder threshold(int threshold) { + this.pageHinckleyThreshold = threshold; + return this; + } + + public Builder alpha(double alpha) { + this.pageHinckleyAlpha = alpha; + return this; + } + + public Builder changeDetection(boolean changeDetection) { + this.driftDetection = changeDetection; + return this; + } + + public Builder predictionFunction(int predictionFunction) { + this.predictionFunction = predictionFunction; + return this; + } + + public Builder constantLearningRatioDecay(boolean constantDecay) { + this.constantLearningRatioDecay = constantDecay; + return this; + } + + public Builder learningRatio(double learningRatio) { + this.learningRatio = learningRatio; + return this; + } + + public Builder splitConfidence(double splitConfidence) { + this.splitConfidence = splitConfidence; + return this; + } + + public Builder tieThreshold(double tieThreshold) { + this.tieThreshold = tieThreshold; + return this; + } + + public Builder gracePeriod(int gracePeriod) { + this.gracePeriod = gracePeriod; + return this; + } + + public Builder noAnomalyDetection(boolean noAnomalyDetection) { + this.noAnomalyDetection = noAnomalyDetection; + return this; + } + + public Builder multivariateAnomalyProbabilityThreshold(double mAnomalyThreshold) { + this.multivariateAnomalyProbabilityThreshold = mAnomalyThreshold; + return this; + } + + public Builder univariateAnomalyProbabilityThreshold(double uAnomalyThreshold) { + this.univariateAnomalyprobabilityThreshold = uAnomalyThreshold; + return this; + } + + public Builder anomalyNumberOfInstancesThreshold(int anomalyNumInstThreshold) { + this.anomalyNumInstThreshold = anomalyNumInstThreshold; + return this; + } + + public Builder unorderedRules(boolean unorderedRules) { + this.unorderedRules = unorderedRules; + return this; + } + + public Builder numericObserver(FIMTDDNumericAttributeClassLimitObserver numericObserver) { + this.numericObserver = numericObserver; + return this; + } + + public Builder voteType(int voteType) { + this.voteType = voteType; + return this; + } + + public AMRulesAggregatorProcessor build() { + return new AMRulesAggregatorProcessor(this); + } + } }
http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/23a35dbe/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/distributed/AMRulesStatisticsProcessor.java ---------------------------------------------------------------------- diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/distributed/AMRulesStatisticsProcessor.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/distributed/AMRulesStatisticsProcessor.java index da820d8..2f1cb18 100644 --- a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/distributed/AMRulesStatisticsProcessor.java +++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/distributed/AMRulesStatisticsProcessor.java @@ -41,180 +41,180 @@ import com.yahoo.labs.samoa.topology.Stream; * Learner Processor (VAMR). * * @author Anh Thu Vu - * + * */ public class AMRulesStatisticsProcessor implements Processor { - /** + /** * */ - private static final long serialVersionUID = 5268933189695395573L; - - private static final Logger logger = - LoggerFactory.getLogger(AMRulesStatisticsProcessor.class); - - private int processorId; - - private transient List<ActiveRule> ruleSet; - - private Stream outputStream; - - private double splitConfidence; - private double tieThreshold; - private int gracePeriod; - - private int frequency; - - public AMRulesStatisticsProcessor(Builder builder) { - this.splitConfidence = builder.splitConfidence; - this.tieThreshold = builder.tieThreshold; - this.gracePeriod = builder.gracePeriod; - this.frequency = builder.frequency; - } - - @Override - public boolean process(ContentEvent event) { - if (event instanceof AssignmentContentEvent) { - - AssignmentContentEvent attrContentEvent = (AssignmentContentEvent) event; - trainRuleOnInstance(attrContentEvent.getRuleNumberID(),attrContentEvent.getInstance()); - } - else if (event instanceof RuleContentEvent) { - RuleContentEvent ruleContentEvent = (RuleContentEvent) event; - if (!ruleContentEvent.isRemoving()) { - addRule(ruleContentEvent.getRule()); - } - } - - return false; - } - - /* - * Process input instances - */ - private void trainRuleOnInstance(int ruleID, Instance instance) { - Iterator<ActiveRule> ruleIterator= this.ruleSet.iterator(); - while (ruleIterator.hasNext()) { - ActiveRule rule = ruleIterator.next(); - if (rule.getRuleNumberID() == ruleID) { - // Check (again) for coverage - // Skip anomaly check as Aggregator's perceptron should be well-updated - if (rule.isCovering(instance) == true) { - double error = rule.computeError(instance); //Use adaptive mode error - boolean changeDetected = ((RuleActiveRegressionNode)rule.getLearningNode()).updateChangeDetection(error); - if (changeDetected == true) { - ruleIterator.remove(); - - this.sendRemoveRuleEvent(ruleID); - } else { - rule.updateStatistics(instance); - if (rule.getInstancesSeen() % this.gracePeriod == 0.0) { - if (rule.tryToExpand(this.splitConfidence, this.tieThreshold) ) { - rule.split(); - - // expanded: update Aggregator with new/updated predicate - this.sendPredicate(rule.getRuleNumberID(), rule.getLastUpdatedRuleSplitNode(), - (RuleActiveRegressionNode)rule.getLearningNode()); - } - } - } - } - - return; - } - } - } - - private void sendRemoveRuleEvent(int ruleID) { - RuleContentEvent rce = new RuleContentEvent(ruleID, null, true); - this.outputStream.put(rce); - } - - private void sendPredicate(int ruleID, RuleSplitNode splitNode, RuleActiveRegressionNode learningNode) { - this.outputStream.put(new PredicateContentEvent(ruleID, splitNode, new RulePassiveRegressionNode(learningNode))); - } - - /* - * Process control message (regarding adding or removing rules) - */ - private boolean addRule(ActiveRule rule) { - this.ruleSet.add(rule); - return true; - } - - @Override - public void onCreate(int id) { - this.processorId = id; - this.ruleSet = new LinkedList<ActiveRule>(); - } - - @Override - public Processor newProcessor(Processor p) { - AMRulesStatisticsProcessor oldProcessor = (AMRulesStatisticsProcessor)p; - AMRulesStatisticsProcessor newProcessor = - new AMRulesStatisticsProcessor.Builder(oldProcessor).build(); - - newProcessor.setOutputStream(oldProcessor.outputStream); - return newProcessor; - } - - /* - * Builder - */ - public static class Builder { - private double splitConfidence; - private double tieThreshold; - private int gracePeriod; - - private int frequency; - - private Instances dataset; - - public Builder(Instances dataset){ - this.dataset = dataset; - } - - public Builder(AMRulesStatisticsProcessor processor) { - this.splitConfidence = processor.splitConfidence; - this.tieThreshold = processor.tieThreshold; - this.gracePeriod = processor.gracePeriod; - this.frequency = processor.frequency; - } - - public Builder splitConfidence(double splitConfidence) { - this.splitConfidence = splitConfidence; - return this; - } - - public Builder tieThreshold(double tieThreshold) { - this.tieThreshold = tieThreshold; - return this; - } - - public Builder gracePeriod(int gracePeriod) { - this.gracePeriod = gracePeriod; - return this; - } - - public Builder frequency(int frequency) { - this.frequency = frequency; - return this; - } - - public AMRulesStatisticsProcessor build() { - return new AMRulesStatisticsProcessor(this); - } - } - - /* - * Output stream - */ - public void setOutputStream(Stream stream) { - this.outputStream = stream; - } - - public Stream getOutputStream() { - return this.outputStream; - } + private static final long serialVersionUID = 5268933189695395573L; + + private static final Logger logger = + LoggerFactory.getLogger(AMRulesStatisticsProcessor.class); + + private int processorId; + + private transient List<ActiveRule> ruleSet; + + private Stream outputStream; + + private double splitConfidence; + private double tieThreshold; + private int gracePeriod; + + private int frequency; + + public AMRulesStatisticsProcessor(Builder builder) { + this.splitConfidence = builder.splitConfidence; + this.tieThreshold = builder.tieThreshold; + this.gracePeriod = builder.gracePeriod; + this.frequency = builder.frequency; + } + + @Override + public boolean process(ContentEvent event) { + if (event instanceof AssignmentContentEvent) { + + AssignmentContentEvent attrContentEvent = (AssignmentContentEvent) event; + trainRuleOnInstance(attrContentEvent.getRuleNumberID(), attrContentEvent.getInstance()); + } + else if (event instanceof RuleContentEvent) { + RuleContentEvent ruleContentEvent = (RuleContentEvent) event; + if (!ruleContentEvent.isRemoving()) { + addRule(ruleContentEvent.getRule()); + } + } + + return false; + } + + /* + * Process input instances + */ + private void trainRuleOnInstance(int ruleID, Instance instance) { + Iterator<ActiveRule> ruleIterator = this.ruleSet.iterator(); + while (ruleIterator.hasNext()) { + ActiveRule rule = ruleIterator.next(); + if (rule.getRuleNumberID() == ruleID) { + // Check (again) for coverage + // Skip anomaly check as Aggregator's perceptron should be well-updated + if (rule.isCovering(instance) == true) { + double error = rule.computeError(instance); // Use adaptive mode error + boolean changeDetected = ((RuleActiveRegressionNode) rule.getLearningNode()).updateChangeDetection(error); + if (changeDetected == true) { + ruleIterator.remove(); + + this.sendRemoveRuleEvent(ruleID); + } else { + rule.updateStatistics(instance); + if (rule.getInstancesSeen() % this.gracePeriod == 0.0) { + if (rule.tryToExpand(this.splitConfidence, this.tieThreshold)) { + rule.split(); + + // expanded: update Aggregator with new/updated predicate + this.sendPredicate(rule.getRuleNumberID(), rule.getLastUpdatedRuleSplitNode(), + (RuleActiveRegressionNode) rule.getLearningNode()); + } + } + } + } + + return; + } + } + } + + private void sendRemoveRuleEvent(int ruleID) { + RuleContentEvent rce = new RuleContentEvent(ruleID, null, true); + this.outputStream.put(rce); + } + + private void sendPredicate(int ruleID, RuleSplitNode splitNode, RuleActiveRegressionNode learningNode) { + this.outputStream.put(new PredicateContentEvent(ruleID, splitNode, new RulePassiveRegressionNode(learningNode))); + } + + /* + * Process control message (regarding adding or removing rules) + */ + private boolean addRule(ActiveRule rule) { + this.ruleSet.add(rule); + return true; + } + + @Override + public void onCreate(int id) { + this.processorId = id; + this.ruleSet = new LinkedList<ActiveRule>(); + } + + @Override + public Processor newProcessor(Processor p) { + AMRulesStatisticsProcessor oldProcessor = (AMRulesStatisticsProcessor) p; + AMRulesStatisticsProcessor newProcessor = + new AMRulesStatisticsProcessor.Builder(oldProcessor).build(); + + newProcessor.setOutputStream(oldProcessor.outputStream); + return newProcessor; + } + + /* + * Builder + */ + public static class Builder { + private double splitConfidence; + private double tieThreshold; + private int gracePeriod; + + private int frequency; + + private Instances dataset; + + public Builder(Instances dataset) { + this.dataset = dataset; + } + + public Builder(AMRulesStatisticsProcessor processor) { + this.splitConfidence = processor.splitConfidence; + this.tieThreshold = processor.tieThreshold; + this.gracePeriod = processor.gracePeriod; + this.frequency = processor.frequency; + } + + public Builder splitConfidence(double splitConfidence) { + this.splitConfidence = splitConfidence; + return this; + } + + public Builder tieThreshold(double tieThreshold) { + this.tieThreshold = tieThreshold; + return this; + } + + public Builder gracePeriod(int gracePeriod) { + this.gracePeriod = gracePeriod; + return this; + } + + public Builder frequency(int frequency) { + this.frequency = frequency; + return this; + } + + public AMRulesStatisticsProcessor build() { + return new AMRulesStatisticsProcessor(this); + } + } + + /* + * Output stream + */ + public void setOutputStream(Stream stream) { + this.outputStream = stream; + } + + public Stream getOutputStream() { + return this.outputStream; + } } http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/23a35dbe/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/distributed/AssignmentContentEvent.java ---------------------------------------------------------------------- diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/distributed/AssignmentContentEvent.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/distributed/AssignmentContentEvent.java index 5a03406..43814b4 100644 --- a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/distributed/AssignmentContentEvent.java +++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/distributed/AssignmentContentEvent.java @@ -27,48 +27,48 @@ import com.yahoo.labs.samoa.instances.Instance; * Forwarded instances from Model Agrregator to Learners/Default Rule Learner. * * @author Anh Thu Vu - * + * */ public class AssignmentContentEvent implements ContentEvent { - /** + /** * */ - private static final long serialVersionUID = 1031695762172836629L; + private static final long serialVersionUID = 1031695762172836629L; + + private int ruleNumberID; + private Instance instance; + + public AssignmentContentEvent() { + this(0, null); + } + + public AssignmentContentEvent(int ruleID, Instance instance) { + this.ruleNumberID = ruleID; + this.instance = instance; + } + + @Override + public String getKey() { + return Integer.toString(this.ruleNumberID); + } + + @Override + public void setKey(String key) { + // do nothing + } - private int ruleNumberID; - private Instance instance; - - public AssignmentContentEvent() { - this(0, null); - } - - public AssignmentContentEvent(int ruleID, Instance instance) { - this.ruleNumberID = ruleID; - this.instance = instance; - } - - @Override - public String getKey() { - return Integer.toString(this.ruleNumberID); - } + @Override + public boolean isLastEvent() { + return false; + } - @Override - public void setKey(String key) { - // do nothing - } + public Instance getInstance() { + return this.instance; + } - @Override - public boolean isLastEvent() { - return false; - } - - public Instance getInstance() { - return this.instance; - } - - public int getRuleNumberID() { - return this.ruleNumberID; - } + public int getRuleNumberID() { + return this.ruleNumberID; + } } http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/23a35dbe/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/distributed/PredicateContentEvent.java ---------------------------------------------------------------------- diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/distributed/PredicateContentEvent.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/distributed/PredicateContentEvent.java index 69e935a..f6c8934 100644 --- a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/distributed/PredicateContentEvent.java +++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/distributed/PredicateContentEvent.java @@ -28,57 +28,58 @@ import com.yahoo.labs.samoa.learners.classifiers.rules.common.RuleSplitNode; * New features (of newly expanded rules) from Learners to Model Aggregators. * * @author Anh Thu Vu - * + * */ public class PredicateContentEvent implements ContentEvent { - /** + /** * */ - private static final long serialVersionUID = 7909435830443732451L; - - private int ruleNumberID; - private RuleSplitNode ruleSplitNode; - private RulePassiveRegressionNode learningNode; - - /* - * Constructor - */ - public PredicateContentEvent() { - this(0, null, null); - } - - public PredicateContentEvent (int ruleID, RuleSplitNode ruleSplitNode, RulePassiveRegressionNode learningNode) { - this.ruleNumberID = ruleID; - this.ruleSplitNode = ruleSplitNode; // is this is null: this is for updating learningNode only - this.learningNode = learningNode; - } - - @Override - public String getKey() { - return Integer.toString(this.ruleNumberID); - } + private static final long serialVersionUID = 7909435830443732451L; + + private int ruleNumberID; + private RuleSplitNode ruleSplitNode; + private RulePassiveRegressionNode learningNode; + + /* + * Constructor + */ + public PredicateContentEvent() { + this(0, null, null); + } + + public PredicateContentEvent(int ruleID, RuleSplitNode ruleSplitNode, RulePassiveRegressionNode learningNode) { + this.ruleNumberID = ruleID; + this.ruleSplitNode = ruleSplitNode; // is this is null: this is for updating + // learningNode only + this.learningNode = learningNode; + } + + @Override + public String getKey() { + return Integer.toString(this.ruleNumberID); + } + + @Override + public void setKey(String key) { + // do nothing + } + + @Override + public boolean isLastEvent() { + return false; // N/A + } + + public int getRuleNumberID() { + return this.ruleNumberID; + } - @Override - public void setKey(String key) { - // do nothing - } + public RuleSplitNode getRuleSplitNode() { + return this.ruleSplitNode; + } - @Override - public boolean isLastEvent() { - return false; // N/A - } - - public int getRuleNumberID() { - return this.ruleNumberID; - } - - public RuleSplitNode getRuleSplitNode() { - return this.ruleSplitNode; - } - - public RulePassiveRegressionNode getLearningNode() { - return this.learningNode; - } + public RulePassiveRegressionNode getLearningNode() { + return this.learningNode; + } } http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/23a35dbe/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/distributed/RuleContentEvent.java ---------------------------------------------------------------------- diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/distributed/RuleContentEvent.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/distributed/RuleContentEvent.java index a9dab4a..ac7aced 100644 --- a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/distributed/RuleContentEvent.java +++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/rules/distributed/RuleContentEvent.java @@ -24,59 +24,59 @@ import com.yahoo.labs.samoa.core.ContentEvent; import com.yahoo.labs.samoa.learners.classifiers.rules.common.ActiveRule; /** - * New rule from Model Aggregator/Default Rule Learner to Learners - * or removed rule from Learner to Model Aggregators. + * New rule from Model Aggregator/Default Rule Learner to Learners or removed + * rule from Learner to Model Aggregators. * * @author Anh Thu Vu - * + * */ public class RuleContentEvent implements ContentEvent { - - /** + /** * */ - private static final long serialVersionUID = -9046390274402894461L; - - private final int ruleNumberID; - private final ActiveRule addingRule; // for removing rule, we only need the rule's ID - private final boolean isRemoving; - - public RuleContentEvent() { - this(0, null, false); - } - - public RuleContentEvent(int ruleID, ActiveRule rule, boolean isRemoving) { - this.ruleNumberID = ruleID; - this.isRemoving = isRemoving; - this.addingRule = rule; - } + private static final long serialVersionUID = -9046390274402894461L; + + private final int ruleNumberID; + private final ActiveRule addingRule; // for removing rule, we only need the + // rule's ID + private final boolean isRemoving; + + public RuleContentEvent() { + this(0, null, false); + } + + public RuleContentEvent(int ruleID, ActiveRule rule, boolean isRemoving) { + this.ruleNumberID = ruleID; + this.isRemoving = isRemoving; + this.addingRule = rule; + } + + @Override + public String getKey() { + return Integer.toString(this.ruleNumberID); + } + + @Override + public void setKey(String key) { + // do nothing + } + + @Override + public boolean isLastEvent() { + return false; + } - @Override - public String getKey() { - return Integer.toString(this.ruleNumberID); - } + public int getRuleNumberID() { + return this.ruleNumberID; + } - @Override - public void setKey(String key) { - // do nothing - } + public ActiveRule getRule() { + return this.addingRule; + } - @Override - public boolean isLastEvent() { - return false; - } - - public int getRuleNumberID() { - return this.ruleNumberID; - } - - public ActiveRule getRule() { - return this.addingRule; - } - - public boolean isRemoving() { - return this.isRemoving; - } + public boolean isRemoving() { + return this.isRemoving; + } } http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/23a35dbe/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/trees/ActiveLearningNode.java ---------------------------------------------------------------------- diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/trees/ActiveLearningNode.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/trees/ActiveLearningNode.java index 40d260c..39abbbe 100644 --- a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/trees/ActiveLearningNode.java +++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/trees/ActiveLearningNode.java @@ -31,177 +31,178 @@ import org.slf4j.LoggerFactory; import com.yahoo.labs.samoa.instances.Instance; final class ActiveLearningNode extends LearningNode { - /** + /** * */ - private static final long serialVersionUID = -2892102872646338908L; - private static final Logger logger = LoggerFactory.getLogger(ActiveLearningNode.class); - - private double weightSeenAtLastSplitEvaluation; - - private final Map<Integer, String> attributeContentEventKeys; - - private AttributeSplitSuggestion bestSuggestion; - private AttributeSplitSuggestion secondBestSuggestion; - - private final long id; - private final int parallelismHint; - private int suggestionCtr; - private int thrownAwayInstance; - - private boolean isSplitting; - - ActiveLearningNode(double[] classObservation, int parallelismHint) { - super(classObservation); - this.weightSeenAtLastSplitEvaluation = this.getWeightSeen(); - this.id = VerticalHoeffdingTree.LearningNodeIdGenerator.generate(); - this.attributeContentEventKeys = new HashMap<>(); - this.isSplitting = false; - this.parallelismHint = parallelismHint; - } - - long getId(){ - return id; - } - - protected AttributeBatchContentEvent[] attributeBatchContentEvent; - - public AttributeBatchContentEvent[] getAttributeBatchContentEvent() { - return this.attributeBatchContentEvent; + private static final long serialVersionUID = -2892102872646338908L; + private static final Logger logger = LoggerFactory.getLogger(ActiveLearningNode.class); + + private double weightSeenAtLastSplitEvaluation; + + private final Map<Integer, String> attributeContentEventKeys; + + private AttributeSplitSuggestion bestSuggestion; + private AttributeSplitSuggestion secondBestSuggestion; + + private final long id; + private final int parallelismHint; + private int suggestionCtr; + private int thrownAwayInstance; + + private boolean isSplitting; + + ActiveLearningNode(double[] classObservation, int parallelismHint) { + super(classObservation); + this.weightSeenAtLastSplitEvaluation = this.getWeightSeen(); + this.id = VerticalHoeffdingTree.LearningNodeIdGenerator.generate(); + this.attributeContentEventKeys = new HashMap<>(); + this.isSplitting = false; + this.parallelismHint = parallelismHint; + } + + long getId() { + return id; + } + + protected AttributeBatchContentEvent[] attributeBatchContentEvent; + + public AttributeBatchContentEvent[] getAttributeBatchContentEvent() { + return this.attributeBatchContentEvent; + } + + public void setAttributeBatchContentEvent(AttributeBatchContentEvent[] attributeBatchContentEvent) { + this.attributeBatchContentEvent = attributeBatchContentEvent; + } + + @Override + void learnFromInstance(Instance inst, ModelAggregatorProcessor proc) { + // TODO: what statistics should we keep for unused instance? + if (isSplitting) { // currently throw all instance will splitting + this.thrownAwayInstance++; + return; } + this.observedClassDistribution.addToValue((int) inst.classValue(), + inst.weight()); + // done: parallelize by sending attributes one by one + // TODO: meanwhile, we can try to use the ThreadPool to execute it + // separately + // TODO: parallelize by sending in batch, i.e. split the attributes into + // chunk instead of send the attribute one by one + for (int i = 0; i < inst.numAttributes() - 1; i++) { + int instAttIndex = modelAttIndexToInstanceAttIndex(i, inst); + Integer obsIndex = i; + String key = attributeContentEventKeys.get(obsIndex); - public void setAttributeBatchContentEvent(AttributeBatchContentEvent[] attributeBatchContentEvent) { - this.attributeBatchContentEvent = attributeBatchContentEvent; + if (key == null) { + key = this.generateKey(i); + attributeContentEventKeys.put(obsIndex, key); + } + AttributeContentEvent ace = new AttributeContentEvent.Builder( + this.id, i, key) + .attrValue(inst.value(instAttIndex)) + .classValue((int) inst.classValue()) + .weight(inst.weight()) + .isNominal(inst.attribute(instAttIndex).isNominal()) + .build(); + if (this.attributeBatchContentEvent == null) { + this.attributeBatchContentEvent = new AttributeBatchContentEvent[inst.numAttributes() - 1]; + } + if (this.attributeBatchContentEvent[i] == null) { + this.attributeBatchContentEvent[i] = new AttributeBatchContentEvent.Builder( + this.id, i, key) + // .attrValue(inst.value(instAttIndex)) + // .classValue((int) inst.classValue()) + // .weight(inst.weight()] + .isNominal(inst.attribute(instAttIndex).isNominal()) + .build(); + } + this.attributeBatchContentEvent[i].add(ace); + // proc.sendToAttributeStream(ace); } - - @Override - void learnFromInstance(Instance inst, ModelAggregatorProcessor proc) { - //TODO: what statistics should we keep for unused instance? - if(isSplitting){ //currently throw all instance will splitting - this.thrownAwayInstance++; - return; - } - this.observedClassDistribution.addToValue((int)inst.classValue(), - inst.weight()); - //done: parallelize by sending attributes one by one - //TODO: meanwhile, we can try to use the ThreadPool to execute it separately - //TODO: parallelize by sending in batch, i.e. split the attributes into - //chunk instead of send the attribute one by one - for(int i = 0; i < inst.numAttributes() - 1; i++){ - int instAttIndex = modelAttIndexToInstanceAttIndex(i, inst); - Integer obsIndex = i; - String key = attributeContentEventKeys.get(obsIndex); - - if(key == null){ - key = this.generateKey(i); - attributeContentEventKeys.put(obsIndex, key); - } - AttributeContentEvent ace = new AttributeContentEvent.Builder( - this.id, i, key) - .attrValue(inst.value(instAttIndex)) - .classValue((int) inst.classValue()) - .weight(inst.weight()) - .isNominal(inst.attribute(instAttIndex).isNominal()) - .build(); - if (this.attributeBatchContentEvent == null){ - this.attributeBatchContentEvent = new AttributeBatchContentEvent[inst.numAttributes() - 1]; - } - if (this.attributeBatchContentEvent[i] == null){ - this.attributeBatchContentEvent[i] = new AttributeBatchContentEvent.Builder( - this.id, i, key) - //.attrValue(inst.value(instAttIndex)) - //.classValue((int) inst.classValue()) - //.weight(inst.weight()] - .isNominal(inst.attribute(instAttIndex).isNominal()) - .build(); - } - this.attributeBatchContentEvent[i].add(ace); - //proc.sendToAttributeStream(ace); - } - } - - @Override - double[] getClassVotes(Instance inst, ModelAggregatorProcessor map) { - return this.observedClassDistribution.getArrayCopy(); - } - - double getWeightSeen(){ - return this.observedClassDistribution.sumOfValues(); - } - - void setWeightSeenAtLastSplitEvaluation(double weight){ - this.weightSeenAtLastSplitEvaluation = weight; - } - - double getWeightSeenAtLastSplitEvaluation(){ - return this.weightSeenAtLastSplitEvaluation; - } - - void requestDistributedSuggestions(long splitId, ModelAggregatorProcessor modelAggrProc) { - this.isSplitting = true; - this.suggestionCtr = 0; - this.thrownAwayInstance = 0; - - ComputeContentEvent cce = new ComputeContentEvent(splitId, this.id, - this.getObservedClassDistribution()); - modelAggrProc.sendToControlStream(cce); - } - - void addDistributedSuggestions(AttributeSplitSuggestion bestSuggestion, AttributeSplitSuggestion secondBestSuggestion){ - //starts comparing from the best suggestion - if(bestSuggestion != null){ - if((this.bestSuggestion == null) || (bestSuggestion.compareTo(this.bestSuggestion) > 0)){ - this.secondBestSuggestion = this.bestSuggestion; - this.bestSuggestion = bestSuggestion; - - if(secondBestSuggestion != null){ - - if((this.secondBestSuggestion == null) || (secondBestSuggestion.compareTo(this.secondBestSuggestion) > 0)){ - this.secondBestSuggestion = secondBestSuggestion; - } - } - }else{ - if((this.secondBestSuggestion == null) || (bestSuggestion.compareTo(this.secondBestSuggestion) > 0)){ - this.secondBestSuggestion = bestSuggestion; - } - } - } - - //TODO: optimize the code to use less memory - this.suggestionCtr++; - } - - boolean isSplitting(){ - return this.isSplitting; - } - - void endSplitting(){ - this.isSplitting = false; - logger.trace("wasted instance: {}", this.thrownAwayInstance); - this.thrownAwayInstance = 0; - } - - AttributeSplitSuggestion getDistributedBestSuggestion(){ - return this.bestSuggestion; - } - - AttributeSplitSuggestion getDistributedSecondBestSuggestion(){ - return this.secondBestSuggestion; - } - - boolean isAllSuggestionsCollected(){ - return (this.suggestionCtr == this.parallelismHint); - } - - private static int modelAttIndexToInstanceAttIndex(int index, Instance inst){ - return inst.classIndex() > index ? index : index + 1; - } - - private String generateKey(int obsIndex){ - final int prime = 31; - int result = 1; - result = prime * result + (int) (this.id ^ (this.id >>> 32)); - result = prime * result + obsIndex; - return Integer.toString(result); - } + } + + @Override + double[] getClassVotes(Instance inst, ModelAggregatorProcessor map) { + return this.observedClassDistribution.getArrayCopy(); + } + + double getWeightSeen() { + return this.observedClassDistribution.sumOfValues(); + } + + void setWeightSeenAtLastSplitEvaluation(double weight) { + this.weightSeenAtLastSplitEvaluation = weight; + } + + double getWeightSeenAtLastSplitEvaluation() { + return this.weightSeenAtLastSplitEvaluation; + } + + void requestDistributedSuggestions(long splitId, ModelAggregatorProcessor modelAggrProc) { + this.isSplitting = true; + this.suggestionCtr = 0; + this.thrownAwayInstance = 0; + + ComputeContentEvent cce = new ComputeContentEvent(splitId, this.id, + this.getObservedClassDistribution()); + modelAggrProc.sendToControlStream(cce); + } + + void addDistributedSuggestions(AttributeSplitSuggestion bestSuggestion, AttributeSplitSuggestion secondBestSuggestion) { + // starts comparing from the best suggestion + if (bestSuggestion != null) { + if ((this.bestSuggestion == null) || (bestSuggestion.compareTo(this.bestSuggestion) > 0)) { + this.secondBestSuggestion = this.bestSuggestion; + this.bestSuggestion = bestSuggestion; + + if (secondBestSuggestion != null) { + + if ((this.secondBestSuggestion == null) || (secondBestSuggestion.compareTo(this.secondBestSuggestion) > 0)) { + this.secondBestSuggestion = secondBestSuggestion; + } + } + } else { + if ((this.secondBestSuggestion == null) || (bestSuggestion.compareTo(this.secondBestSuggestion) > 0)) { + this.secondBestSuggestion = bestSuggestion; + } + } + } + + // TODO: optimize the code to use less memory + this.suggestionCtr++; + } + + boolean isSplitting() { + return this.isSplitting; + } + + void endSplitting() { + this.isSplitting = false; + logger.trace("wasted instance: {}", this.thrownAwayInstance); + this.thrownAwayInstance = 0; + } + + AttributeSplitSuggestion getDistributedBestSuggestion() { + return this.bestSuggestion; + } + + AttributeSplitSuggestion getDistributedSecondBestSuggestion() { + return this.secondBestSuggestion; + } + + boolean isAllSuggestionsCollected() { + return (this.suggestionCtr == this.parallelismHint); + } + + private static int modelAttIndexToInstanceAttIndex(int index, Instance inst) { + return inst.classIndex() > index ? index : index + 1; + } + + private String generateKey(int obsIndex) { + final int prime = 31; + int result = 1; + result = prime * result + (int) (this.id ^ (this.id >>> 32)); + result = prime * result + obsIndex; + return Integer.toString(result); + } } http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/23a35dbe/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/trees/AttributeBatchContentEvent.java ---------------------------------------------------------------------- diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/trees/AttributeBatchContentEvent.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/trees/AttributeBatchContentEvent.java index 691d0fb..c67efcb 100644 --- a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/trees/AttributeBatchContentEvent.java +++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/trees/AttributeBatchContentEvent.java @@ -25,110 +25,112 @@ import java.util.LinkedList; import java.util.List; /** - * Attribute Content Event represents the instances that split vertically - * based on their attribute + * Attribute Content Event represents the instances that split vertically based + * on their attribute + * * @author Arinto Murdopo - * + * */ final class AttributeBatchContentEvent implements ContentEvent { - private static final long serialVersionUID = 6652815649846676832L; - - private final long learningNodeId; - private final int obsIndex; - private final List<ContentEvent> contentEventList; - private final transient String key; - private final boolean isNominal; - - public AttributeBatchContentEvent(){ - learningNodeId = -1; - obsIndex = -1; - contentEventList = new LinkedList<>(); - key = ""; - isNominal = true; - } - - private AttributeBatchContentEvent(Builder builder){ - this.learningNodeId = builder.learningNodeId; - this.obsIndex = builder.obsIndex; - this.contentEventList = new LinkedList<>(); - if (builder.contentEvent != null) { - this.contentEventList.add(builder.contentEvent); - } - this.isNominal = builder.isNominal; - this.key = builder.key; - } - - public void add(ContentEvent contentEvent){ - this.contentEventList.add(contentEvent); - } - - @Override - public String getKey() { - return this.key; - } - - @Override - public void setKey(String str) { - //do nothing, maybe useful when we want to reuse the object for serialization/deserialization purpose - } - - @Override - public boolean isLastEvent() { - return false; - } - - long getLearningNodeId(){ - return this.learningNodeId; - } - - int getObsIndex(){ - return this.obsIndex; - } - - public List<ContentEvent> getContentEventList(){ - return this.contentEventList; - } - - boolean isNominal(){ - return this.isNominal; - } - - static final class Builder{ - - //required parameters - private final long learningNodeId; - private final int obsIndex; - private final String key; - - private ContentEvent contentEvent; - private boolean isNominal = false; - - Builder(long id, int obsIndex, String key){ - this.learningNodeId = id; - this.obsIndex = obsIndex; - this.key = key; - } - - private Builder(long id, int obsIndex){ - this.learningNodeId = id; - this.obsIndex = obsIndex; - this.key = ""; - } - - Builder contentEvent(ContentEvent contentEvent){ - this.contentEvent = contentEvent; - return this; - } - - Builder isNominal(boolean val){ - this.isNominal = val; - return this; - } - - AttributeBatchContentEvent build(){ - return new AttributeBatchContentEvent(this); - } - } - + private static final long serialVersionUID = 6652815649846676832L; + + private final long learningNodeId; + private final int obsIndex; + private final List<ContentEvent> contentEventList; + private final transient String key; + private final boolean isNominal; + + public AttributeBatchContentEvent() { + learningNodeId = -1; + obsIndex = -1; + contentEventList = new LinkedList<>(); + key = ""; + isNominal = true; + } + + private AttributeBatchContentEvent(Builder builder) { + this.learningNodeId = builder.learningNodeId; + this.obsIndex = builder.obsIndex; + this.contentEventList = new LinkedList<>(); + if (builder.contentEvent != null) { + this.contentEventList.add(builder.contentEvent); + } + this.isNominal = builder.isNominal; + this.key = builder.key; + } + + public void add(ContentEvent contentEvent) { + this.contentEventList.add(contentEvent); + } + + @Override + public String getKey() { + return this.key; + } + + @Override + public void setKey(String str) { + // do nothing, maybe useful when we want to reuse the object for + // serialization/deserialization purpose + } + + @Override + public boolean isLastEvent() { + return false; + } + + long getLearningNodeId() { + return this.learningNodeId; + } + + int getObsIndex() { + return this.obsIndex; + } + + public List<ContentEvent> getContentEventList() { + return this.contentEventList; + } + + boolean isNominal() { + return this.isNominal; + } + + static final class Builder { + + // required parameters + private final long learningNodeId; + private final int obsIndex; + private final String key; + + private ContentEvent contentEvent; + private boolean isNominal = false; + + Builder(long id, int obsIndex, String key) { + this.learningNodeId = id; + this.obsIndex = obsIndex; + this.key = key; + } + + private Builder(long id, int obsIndex) { + this.learningNodeId = id; + this.obsIndex = obsIndex; + this.key = ""; + } + + Builder contentEvent(ContentEvent contentEvent) { + this.contentEvent = contentEvent; + return this; + } + + Builder isNominal(boolean val) { + this.isNominal = val; + return this; + } + + AttributeBatchContentEvent build() { + return new AttributeBatchContentEvent(this); + } + } + } http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/23a35dbe/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/trees/AttributeContentEvent.java ---------------------------------------------------------------------- diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/trees/AttributeContentEvent.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/trees/AttributeContentEvent.java index 4cbdd95..ca45d30 100644 --- a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/trees/AttributeContentEvent.java +++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/trees/AttributeContentEvent.java @@ -28,195 +28,198 @@ import com.esotericsoftware.kryo.io.Output; import com.yahoo.labs.samoa.core.ContentEvent; /** - * Attribute Content Event represents the instances that split vertically - * based on their attribute + * Attribute Content Event represents the instances that split vertically based + * on their attribute + * * @author Arinto Murdopo - * + * */ public final class AttributeContentEvent implements ContentEvent { - private static final long serialVersionUID = 6652815649846676832L; - - private final long learningNodeId; - private final int obsIndex; - private final double attrVal; - private final int classVal; - private final double weight; - private final transient String key; - private final boolean isNominal; - - public AttributeContentEvent(){ - learningNodeId = -1; - obsIndex = -1; - attrVal = 0.0; - classVal = -1; - weight = 0.0; - key = ""; - isNominal = true; - } - - private AttributeContentEvent(Builder builder){ - this.learningNodeId = builder.learningNodeId; - this.obsIndex = builder.obsIndex; - this.attrVal = builder.attrVal; - this.classVal = builder.classVal; - this.weight = builder.weight; - this.isNominal = builder.isNominal; - this.key = builder.key; - } - - @Override - public String getKey() { - return this.key; - } - - @Override - public void setKey(String str) { - //do nothing, maybe useful when we want to reuse the object for serialization/deserialization purpose - } - - @Override - public boolean isLastEvent() { - return false; - } - - long getLearningNodeId(){ - return this.learningNodeId; - } - - int getObsIndex(){ - return this.obsIndex; - } - - int getClassVal(){ - return this.classVal; - } - - double getAttrVal(){ - return this.attrVal; - } - - double getWeight(){ - return this.weight; - } - - boolean isNominal(){ - return this.isNominal; - } - - static final class Builder{ - - //required parameters - private final long learningNodeId; - private final int obsIndex; - private final String key; - - //optional parameters - private double attrVal = 0.0; - private int classVal = 0; - private double weight = 0.0; - private boolean isNominal = false; - - Builder(long id, int obsIndex, String key){ - this.learningNodeId = id; - this.obsIndex = obsIndex; - this.key = key; - } - - private Builder(long id, int obsIndex){ - this.learningNodeId = id; - this.obsIndex = obsIndex; - this.key = ""; - } - - Builder attrValue(double val){ - this.attrVal = val; - return this; - } - - Builder classValue(int val){ - this.classVal = val; - return this; - } - - Builder weight(double val){ - this.weight = val; - return this; - } - - Builder isNominal(boolean val){ - this.isNominal = val; - return this; - } - - AttributeContentEvent build(){ - return new AttributeContentEvent(this); - } - } - - /** - * The Kryo serializer class for AttributeContentEvent when executing on top of Storm. - * This class allow us to change the precision of the statistics. - * @author Arinto Murdopo - * - */ - public static final class AttributeCESerializer extends Serializer<AttributeContentEvent>{ - - private static double PRECISION = 1000000.0; - @Override - public void write(Kryo kryo, Output output, AttributeContentEvent event) { - output.writeLong(event.learningNodeId, true); - output.writeInt(event.obsIndex, true); - output.writeDouble(event.attrVal, PRECISION, true); - output.writeInt(event.classVal, true); - output.writeDouble(event.weight, PRECISION, true); - output.writeBoolean(event.isNominal); - } - - @Override - public AttributeContentEvent read(Kryo kryo, Input input, - Class<AttributeContentEvent> type) { - AttributeContentEvent ace - = new AttributeContentEvent.Builder(input.readLong(true), input.readInt(true)) - .attrValue(input.readDouble(PRECISION, true)) - .classValue(input.readInt(true)) - .weight(input.readDouble(PRECISION, true)) - .isNominal(input.readBoolean()) - .build(); - return ace; - } - } - - /** - * The Kryo serializer class for AttributeContentEvent when executing on top of Storm - * with full precision of the statistics. - * @author Arinto Murdopo - * - */ - public static final class AttributeCEFullPrecSerializer extends Serializer<AttributeContentEvent>{ - - @Override - public void write(Kryo kryo, Output output, AttributeContentEvent event) { - output.writeLong(event.learningNodeId, true); - output.writeInt(event.obsIndex, true); - output.writeDouble(event.attrVal); - output.writeInt(event.classVal, true); - output.writeDouble(event.weight); - output.writeBoolean(event.isNominal); - } - - @Override - public AttributeContentEvent read(Kryo kryo, Input input, - Class<AttributeContentEvent> type) { - AttributeContentEvent ace - = new AttributeContentEvent.Builder(input.readLong(true), input.readInt(true)) - .attrValue(input.readDouble()) - .classValue(input.readInt(true)) - .weight(input.readDouble()) - .isNominal(input.readBoolean()) - .build(); - return ace; - } - - } + private static final long serialVersionUID = 6652815649846676832L; + + private final long learningNodeId; + private final int obsIndex; + private final double attrVal; + private final int classVal; + private final double weight; + private final transient String key; + private final boolean isNominal; + + public AttributeContentEvent() { + learningNodeId = -1; + obsIndex = -1; + attrVal = 0.0; + classVal = -1; + weight = 0.0; + key = ""; + isNominal = true; + } + + private AttributeContentEvent(Builder builder) { + this.learningNodeId = builder.learningNodeId; + this.obsIndex = builder.obsIndex; + this.attrVal = builder.attrVal; + this.classVal = builder.classVal; + this.weight = builder.weight; + this.isNominal = builder.isNominal; + this.key = builder.key; + } + + @Override + public String getKey() { + return this.key; + } + + @Override + public void setKey(String str) { + // do nothing, maybe useful when we want to reuse the object for + // serialization/deserialization purpose + } + + @Override + public boolean isLastEvent() { + return false; + } + + long getLearningNodeId() { + return this.learningNodeId; + } + + int getObsIndex() { + return this.obsIndex; + } + + int getClassVal() { + return this.classVal; + } + + double getAttrVal() { + return this.attrVal; + } + + double getWeight() { + return this.weight; + } + + boolean isNominal() { + return this.isNominal; + } + + static final class Builder { + + // required parameters + private final long learningNodeId; + private final int obsIndex; + private final String key; + + // optional parameters + private double attrVal = 0.0; + private int classVal = 0; + private double weight = 0.0; + private boolean isNominal = false; + + Builder(long id, int obsIndex, String key) { + this.learningNodeId = id; + this.obsIndex = obsIndex; + this.key = key; + } + + private Builder(long id, int obsIndex) { + this.learningNodeId = id; + this.obsIndex = obsIndex; + this.key = ""; + } + + Builder attrValue(double val) { + this.attrVal = val; + return this; + } + + Builder classValue(int val) { + this.classVal = val; + return this; + } + + Builder weight(double val) { + this.weight = val; + return this; + } + + Builder isNominal(boolean val) { + this.isNominal = val; + return this; + } + + AttributeContentEvent build() { + return new AttributeContentEvent(this); + } + } + + /** + * The Kryo serializer class for AttributeContentEvent when executing on top + * of Storm. This class allow us to change the precision of the statistics. + * + * @author Arinto Murdopo + * + */ + public static final class AttributeCESerializer extends Serializer<AttributeContentEvent> { + + private static double PRECISION = 1000000.0; + + @Override + public void write(Kryo kryo, Output output, AttributeContentEvent event) { + output.writeLong(event.learningNodeId, true); + output.writeInt(event.obsIndex, true); + output.writeDouble(event.attrVal, PRECISION, true); + output.writeInt(event.classVal, true); + output.writeDouble(event.weight, PRECISION, true); + output.writeBoolean(event.isNominal); + } + + @Override + public AttributeContentEvent read(Kryo kryo, Input input, + Class<AttributeContentEvent> type) { + AttributeContentEvent ace = new AttributeContentEvent.Builder(input.readLong(true), input.readInt(true)) + .attrValue(input.readDouble(PRECISION, true)) + .classValue(input.readInt(true)) + .weight(input.readDouble(PRECISION, true)) + .isNominal(input.readBoolean()) + .build(); + return ace; + } + } + + /** + * The Kryo serializer class for AttributeContentEvent when executing on top + * of Storm with full precision of the statistics. + * + * @author Arinto Murdopo + * + */ + public static final class AttributeCEFullPrecSerializer extends Serializer<AttributeContentEvent> { + + @Override + public void write(Kryo kryo, Output output, AttributeContentEvent event) { + output.writeLong(event.learningNodeId, true); + output.writeInt(event.obsIndex, true); + output.writeDouble(event.attrVal); + output.writeInt(event.classVal, true); + output.writeDouble(event.weight); + output.writeBoolean(event.isNominal); + } + + @Override + public AttributeContentEvent read(Kryo kryo, Input input, + Class<AttributeContentEvent> type) { + AttributeContentEvent ace = new AttributeContentEvent.Builder(input.readLong(true), input.readInt(true)) + .attrValue(input.readDouble()) + .classValue(input.readInt(true)) + .weight(input.readDouble()) + .isNominal(input.readBoolean()) + .build(); + return ace; + } + + } } http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/23a35dbe/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/trees/ComputeContentEvent.java ---------------------------------------------------------------------- diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/trees/ComputeContentEvent.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/trees/ComputeContentEvent.java index 52f4685..8113a41 100644 --- a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/trees/ComputeContentEvent.java +++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/trees/ComputeContentEvent.java @@ -26,117 +26,121 @@ import com.esotericsoftware.kryo.io.Input; import com.esotericsoftware.kryo.io.Output; /** - * Compute content event is the message that is sent by Model Aggregator Processor - * to request Local Statistic PI to start the local statistic calculation for splitting + * Compute content event is the message that is sent by Model Aggregator + * Processor to request Local Statistic PI to start the local statistic + * calculation for splitting + * * @author Arinto Murdopo - * + * */ public final class ComputeContentEvent extends ControlContentEvent { - - private static final long serialVersionUID = 5590798490073395190L; - - private final double[] preSplitDist; - private final long splitId; - - public ComputeContentEvent(){ - super(-1); - preSplitDist = null; - splitId = -1; - } - - ComputeContentEvent(long splitId, long id, double[] preSplitDist) { - super(id); - //this.preSplitDist = Arrays.copyOf(preSplitDist, preSplitDist.length); - this.preSplitDist = preSplitDist; - this.splitId = splitId; - } - - @Override - LocStatControl getType() { - return LocStatControl.COMPUTE; - } - - double[] getPreSplitDist(){ - return this.preSplitDist; - } - - long getSplitId(){ - return this.splitId; - } - - /** - * The Kryo serializer class for ComputeContentEevent when executing on top of Storm. - * This class allow us to change the precision of the statistics. - * @author Arinto Murdopo - * - */ - public static final class ComputeCESerializer extends Serializer<ComputeContentEvent>{ - - private static double PRECISION = 1000000.0; - - @Override - public void write(Kryo kryo, Output output, ComputeContentEvent object) { - output.writeLong(object.splitId, true); - output.writeLong(object.learningNodeId, true); - - output.writeInt(object.preSplitDist.length, true); - for(int i = 0; i < object.preSplitDist.length; i++){ - output.writeDouble(object.preSplitDist[i], PRECISION, true); - } - } - - @Override - public ComputeContentEvent read(Kryo kryo, Input input, - Class<ComputeContentEvent> type) { - long splitId = input.readLong(true); - long learningNodeId = input.readLong(true); - - int dataLength = input.readInt(true); - double[] preSplitDist = new double[dataLength]; - - for(int i = 0; i < dataLength; i++){ - preSplitDist[i] = input.readDouble(PRECISION, true); - } - - return new ComputeContentEvent(splitId, learningNodeId, preSplitDist); - } - } - - /** - * The Kryo serializer class for ComputeContentEevent when executing on top of Storm - * with full precision of the statistics. - * @author Arinto Murdopo - * - */ - public static final class ComputeCEFullPrecSerializer extends Serializer<ComputeContentEvent>{ - - @Override - public void write(Kryo kryo, Output output, ComputeContentEvent object) { - output.writeLong(object.splitId, true); - output.writeLong(object.learningNodeId, true); - - output.writeInt(object.preSplitDist.length, true); - for(int i = 0; i < object.preSplitDist.length; i++){ - output.writeDouble(object.preSplitDist[i]); - } - } - - @Override - public ComputeContentEvent read(Kryo kryo, Input input, - Class<ComputeContentEvent> type) { - long splitId = input.readLong(true); - long learningNodeId = input.readLong(true); - - int dataLength = input.readInt(true); - double[] preSplitDist = new double[dataLength]; - - for(int i = 0; i < dataLength; i++){ - preSplitDist[i] = input.readDouble(); - } - - return new ComputeContentEvent(splitId, learningNodeId, preSplitDist); - } - - } + + private static final long serialVersionUID = 5590798490073395190L; + + private final double[] preSplitDist; + private final long splitId; + + public ComputeContentEvent() { + super(-1); + preSplitDist = null; + splitId = -1; + } + + ComputeContentEvent(long splitId, long id, double[] preSplitDist) { + super(id); + // this.preSplitDist = Arrays.copyOf(preSplitDist, preSplitDist.length); + this.preSplitDist = preSplitDist; + this.splitId = splitId; + } + + @Override + LocStatControl getType() { + return LocStatControl.COMPUTE; + } + + double[] getPreSplitDist() { + return this.preSplitDist; + } + + long getSplitId() { + return this.splitId; + } + + /** + * The Kryo serializer class for ComputeContentEevent when executing on top of + * Storm. This class allow us to change the precision of the statistics. + * + * @author Arinto Murdopo + * + */ + public static final class ComputeCESerializer extends Serializer<ComputeContentEvent> { + + private static double PRECISION = 1000000.0; + + @Override + public void write(Kryo kryo, Output output, ComputeContentEvent object) { + output.writeLong(object.splitId, true); + output.writeLong(object.learningNodeId, true); + + output.writeInt(object.preSplitDist.length, true); + for (int i = 0; i < object.preSplitDist.length; i++) { + output.writeDouble(object.preSplitDist[i], PRECISION, true); + } + } + + @Override + public ComputeContentEvent read(Kryo kryo, Input input, + Class<ComputeContentEvent> type) { + long splitId = input.readLong(true); + long learningNodeId = input.readLong(true); + + int dataLength = input.readInt(true); + double[] preSplitDist = new double[dataLength]; + + for (int i = 0; i < dataLength; i++) { + preSplitDist[i] = input.readDouble(PRECISION, true); + } + + return new ComputeContentEvent(splitId, learningNodeId, preSplitDist); + } + } + + /** + * The Kryo serializer class for ComputeContentEevent when executing on top of + * Storm with full precision of the statistics. + * + * @author Arinto Murdopo + * + */ + public static final class ComputeCEFullPrecSerializer extends Serializer<ComputeContentEvent> { + + @Override + public void write(Kryo kryo, Output output, ComputeContentEvent object) { + output.writeLong(object.splitId, true); + output.writeLong(object.learningNodeId, true); + + output.writeInt(object.preSplitDist.length, true); + for (int i = 0; i < object.preSplitDist.length; i++) { + output.writeDouble(object.preSplitDist[i]); + } + } + + @Override + public ComputeContentEvent read(Kryo kryo, Input input, + Class<ComputeContentEvent> type) { + long splitId = input.readLong(true); + long learningNodeId = input.readLong(true); + + int dataLength = input.readInt(true); + double[] preSplitDist = new double[dataLength]; + + for (int i = 0; i < dataLength; i++) { + preSplitDist[i] = input.readDouble(); + } + + return new ComputeContentEvent(splitId, learningNodeId, preSplitDist); + } + + } }
