http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/9b178f63/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/rules/common/TargetMean.java ---------------------------------------------------------------------- diff --git a/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/rules/common/TargetMean.java b/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/rules/common/TargetMean.java new file mode 100644 index 0000000..a76d85b --- /dev/null +++ b/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/rules/common/TargetMean.java @@ -0,0 +1,223 @@
/*
 * TargetMean.java
 * Copyright (C) 2014 University of Porto, Portugal
 * @author J. Duarte, A. Bifet, J. Gama
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 *
 */
package org.apache.samoa.learners.classifiers.rules.common;

/*
 * #%L
 * SAMOA
 * %%
 * Copyright (C) 2014 - 2015 Apache Software Foundation
 * %%
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * #L%
 */

import org.apache.samoa.instances.Instance;
import org.apache.samoa.moa.classifiers.AbstractClassifier;
import org.apache.samoa.moa.classifiers.Regressor;
import org.apache.samoa.moa.core.Measurement;
import org.apache.samoa.moa.core.StringUtils;

import com.esotericsoftware.kryo.Kryo;
import com.esotericsoftware.kryo.Serializer;
import com.esotericsoftware.kryo.io.Input;
import com.esotericsoftware.kryo.io.Output;
import com.github.javacliparser.FloatOption;

/**
 * Prediction scheme using TargetMean: returns the mean of the target variable of the training
 * instances.
 *
 * <p>
 * Besides the running mean, the learner keeps a faded mean absolute error (see
 * {@link #getCurrentError()}). For each training instance the error is updated with the
 * prediction made BEFORE that instance is added to the mean.
 *
 * @author Joao Duarte
 */
public class TargetMean extends AbstractClassifier implements Regressor {

  private static final long serialVersionUID = 7152547322803559115L;

  /** Number of training instances accumulated into {@link #sum}. */
  protected long n;
  /** Sum of the target (class) values of the accumulated training instances. */
  protected double sum;
  /** Faded sum of absolute prediction errors. */
  protected double errorSum;
  /** Faded count of the instances contributing to {@link #errorSum}. */
  protected double nError;
  /** Multiplier applied to {@link #errorSum} and {@link #nError} on every update. */
  private double fadingErrorFactor;

  public FloatOption fadingErrorFactorOption = new FloatOption(
      "fadingErrorFactor", 'e',
      "Fading error factor for the TargetMean accumulated error", 0.99, 0, 1);

  @Override
  public boolean isRandomizable() {
    return false;
  }

  /**
   * Returns the current mean; the instance's attributes are not used.
   */
  @Override
  public double[] getVotesForInstance(Instance inst) {
    return getVotesForInstance();
  }

  /**
   * Returns the current prediction.
   *
   * @return a one-element array holding the mean of the accumulated targets, or 0 if no
   *         instance has been accumulated yet
   */
  public double[] getVotesForInstance() {
    return new double[] { meanOrZero() };
  }

  /** Mean of the accumulated targets, or 0 when {@link #n} is not positive. */
  private double meanOrZero() {
    return n > 0 ? sum / n : 0;
  }

  /**
   * Clears the mean and the accumulated error.
   *
   * <p>
   * NOTE(review): {@code errorSum} restarts at {@code Double.MAX_VALUE}, so
   * {@link #getCurrentError()} stays very large until the fading factor has decayed it;
   * presumably an intentional pessimistic start — confirm. {@link #resetError()} instead
   * restarts the error at 0.
   */
  @Override
  public void resetLearningImpl() {
    sum = 0;
    n = 0;
    errorSum = Double.MAX_VALUE;
    nError = 0;
    // The constructor only sees the option's default; re-read it so a value configured after
    // construction (e.g. via the command line) takes effect. Guarded because a superclass may
    // call this before field initializers have run.
    if (fadingErrorFactorOption != null) {
      fadingErrorFactor = fadingErrorFactorOption.getValue();
    }
  }

  @Override
  public void trainOnInstanceImpl(Instance inst) {
    // Error first, so it reflects the prediction made without this instance.
    updateAccumulatedError(inst);
    ++this.n;
    this.sum += inst.classValue();
  }

  /**
   * Fades the accumulated error and adds the absolute error of the current mean on
   * {@code inst}.
   *
   * @param inst instance whose class value is compared against the current mean
   */
  protected void updateAccumulatedError(Instance inst) {
    nError = 1 + fadingErrorFactor * nError;
    errorSum = Math.abs(inst.classValue() - meanOrZero()) + fadingErrorFactor * errorSum;
  }

  @Override
  protected Measurement[] getModelMeasurementsImpl() {
    return null;
  }

  @Override
  public void getModelDescription(StringBuilder out, int indent) {
    // Guarded mean: an untrained model reports 0 instead of NaN (0.0 / 0).
    StringUtils.appendIndented(out, indent, "Current Mean: " + meanOrZero());
    StringUtils.appendNewline(out);
  }

  /**
   * Resets the learner but initializes it with a starting point.
   *
   * @param currentMean mean to start from
   * @param numberOfInstances number of instances that mean is assumed to summarize
   */
  public void reset(double currentMean, long numberOfInstances) {
    this.sum = currentMean * numberOfInstances;
    this.n = numberOfInstances;
    this.resetError();
  }

  /**
   * Returns the faded mean absolute error.
   *
   * @return {@code errorSum / nError}, or {@code Double.MAX_VALUE} if no error has been
   *         accumulated yet
   */
  public double getCurrentError() {
    if (this.nError > 0) {
      return this.errorSum / this.nError;
    }
    return Double.MAX_VALUE;
  }

  /**
   * Copy constructor.
   *
   * @param t learner whose state and option value are copied
   */
  public TargetMean(TargetMean t) {
    super();
    this.n = t.n;
    this.sum = t.sum;
    this.errorSum = t.errorSum;
    this.nError = t.nError;
    this.fadingErrorFactor = t.fadingErrorFactor;
    // Copy the option's value instead of sharing the option object; sharing meant that
    // changing the copy's option silently reconfigured the original.
    if (t.fadingErrorFactorOption != null) {
      this.fadingErrorFactorOption.setValue(t.fadingErrorFactorOption.getValue());
    }
  }

  /**
   * Rebuilds a learner from its serialized form.
   *
   * @param td state captured by {@link TargetMeanData}
   */
  public TargetMean(TargetMeanData td) {
    this();
    this.n = td.n;
    this.sum = td.sum;
    this.errorSum = td.errorSum;
    this.nError = td.nError;
    this.fadingErrorFactor = td.fadingErrorFactor;
    this.fadingErrorFactorOption.setValue(td.fadingErrorFactorOptionValue);
  }

  /** Creates a learner whose fading factor is the option's default value. */
  public TargetMean() {
    super();
    fadingErrorFactor = fadingErrorFactorOption.getValue();
  }

  /** Clears the accumulated error, keeping the mean. */
  public void resetError() {
    this.errorSum = 0;
    this.nError = 0;
  }

  /** Plain data holder used by {@link TargetMeanSerializer}. */
  public static class TargetMeanData {
    private long n;
    private double sum;
    private double errorSum;
    private double nError;
    private double fadingErrorFactor;
    private double fadingErrorFactorOptionValue;

    /** No-arg constructor, needed for Kryo instantiation. */
    public TargetMeanData() {

    }

    public TargetMeanData(TargetMean tm) {
      this.n = tm.n;
      this.sum = tm.sum;
      this.errorSum = tm.errorSum;
      this.nError = tm.nError;
      this.fadingErrorFactor = tm.fadingErrorFactor;
      // Without an option object, fall back to the factor actually in use so the rebuilt
      // learner's option and factor agree (previously a hard-coded 0.99).
      this.fadingErrorFactorOptionValue = tm.fadingErrorFactorOption != null
          ? tm.fadingErrorFactorOption.getValue()
          : tm.fadingErrorFactor;
    }

    public TargetMean build() {
      return new TargetMean(this);
    }
  }

  /** Kryo serializer that round-trips a {@link TargetMean} through {@link TargetMeanData}. */
  public static final class TargetMeanSerializer extends Serializer<TargetMean> {

    @Override
    public void write(Kryo kryo, Output output, TargetMean t) {
      kryo.writeObjectOrNull(output, new TargetMeanData(t), TargetMeanData.class);
    }

    @Override
    public TargetMean read(Kryo kryo, Input input, Class<TargetMean> type) {
      TargetMeanData data = kryo.readObjectOrNull(input, TargetMeanData.class);
      return data == null ? null : data.build();
    }
  }
}
http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/9b178f63/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/rules/distributed/AMRDefaultRuleProcessor.java ---------------------------------------------------------------------- diff --git a/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/rules/distributed/AMRDefaultRuleProcessor.java b/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/rules/distributed/AMRDefaultRuleProcessor.java new file mode 100644 index 0000000..9f2b9c2 --- /dev/null +++ b/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/rules/distributed/AMRDefaultRuleProcessor.java @@ -0,0 +1,336 @@
package org.apache.samoa.learners.classifiers.rules.distributed;

/*
 * #%L
 * SAMOA
 * %%
 * Copyright (C) 2014 - 2015 Apache Software Foundation
 * %%
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * #L%
 */

import org.apache.samoa.core.ContentEvent;
import org.apache.samoa.core.Processor;
import org.apache.samoa.instances.Instance;
import org.apache.samoa.instances.Instances;
import org.apache.samoa.learners.InstanceContentEvent;
import org.apache.samoa.learners.ResultContentEvent;
import org.apache.samoa.learners.classifiers.rules.common.ActiveRule;
import org.apache.samoa.learners.classifiers.rules.common.Perceptron;
import org.apache.samoa.learners.classifiers.rules.common.RuleActiveRegressionNode;
import org.apache.samoa.moa.classifiers.rules.core.attributeclassobservers.FIMTDDNumericAttributeClassLimitObserver;
import org.apache.samoa.topology.Stream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Default Rule Learner Processor (HAMR).
 *
 * <p>
 * Owns the default rule. Testing instances are answered with the default rule's prediction on
 * the result stream; training instances update it. When the default rule expands, the expanded
 * rule is emitted on the rule stream as a new rule and replaced by a fresh default rule.
 *
 * @author Anh Thu Vu
 */
public class AMRDefaultRuleProcessor implements Processor {

  private static final long serialVersionUID = 23702084591044447L;

  private static final Logger logger =
      LoggerFactory.getLogger(AMRDefaultRuleProcessor.class);

  private int processorId;

  // Default rule
  protected transient ActiveRule defaultRule;
  /** Last rule ID handed out; incremented before each use. */
  protected transient int ruleNumberID;
  /** Allocated in onCreate; not read elsewhere in this class. */
  protected transient double[] statistics;

  // SAMOA Stream
  private Stream ruleStream;
  private Stream resultStream;

  // Options
  protected int pageHinckleyThreshold;
  protected double pageHinckleyAlpha;
  protected boolean driftDetection;
  protected int predictionFunction; // Adaptive=0 Perceptron=1 TargetMean=2
  protected boolean constantLearningRatioDecay;
  protected double learningRatio;

  protected double splitConfidence;
  protected double tieThreshold;
  protected int gracePeriod;

  protected FIMTDDNumericAttributeClassLimitObserver numericObserver;

  /**
   * Creates the processor from the options collected by {@code builder}.
   *
   * @param builder source of all option values
   */
  public AMRDefaultRuleProcessor(Builder builder) {
    this.pageHinckleyThreshold = builder.pageHinckleyThreshold;
    this.pageHinckleyAlpha = builder.pageHinckleyAlpha;
    this.driftDetection = builder.driftDetection;
    this.predictionFunction = builder.predictionFunction;
    this.constantLearningRatioDecay = builder.constantLearningRatioDecay;
    this.learningRatio = builder.learningRatio;
    this.splitConfidence = builder.splitConfidence;
    this.tieThreshold = builder.tieThreshold;
    this.gracePeriod = builder.gracePeriod;

    this.numericObserver = builder.numericObserver;
  }

  /**
   * Handles one instance: predicts if it is a testing instance, trains if it is a training
   * instance (both may apply).
   *
   * @param event expected to be an {@link InstanceContentEvent}
   * @return always {@code false}
   */
  @Override
  public boolean process(ContentEvent event) {
    InstanceContentEvent instanceEvent = (InstanceContentEvent) event;
    if (instanceEvent.isTesting()) {
      this.predictOnInstance(instanceEvent);
    }
    if (instanceEvent.isTraining()) {
      this.trainOnInstance(instanceEvent);
    }
    return false;
  }

  /*
   * Prediction
   */
  private void predictOnInstance(InstanceContentEvent instanceEvent) {
    double[] vote = defaultRule.getPrediction(instanceEvent.getInstance());
    ResultContentEvent rce = newResultContentEvent(vote, instanceEvent);
    resultStream.put(rce);
  }

  /** Wraps a prediction into a result event tagged with this processor's ID. */
  private ResultContentEvent newResultContentEvent(double[] prediction, InstanceContentEvent inEvent) {
    ResultContentEvent rce = new ResultContentEvent(inEvent.getInstanceIndex(), inEvent.getInstance(),
        inEvent.getClassId(), prediction, inEvent.isLastEvent());
    rce.setClassifierIndex(this.processorId);
    rce.setEvaluationIndex(inEvent.getEvaluationIndex());
    return rce;
  }

  /*
   * Training
   */
  private void trainOnInstance(InstanceContentEvent instanceEvent) {
    this.trainOnInstanceImpl(instanceEvent.getInstance());
  }

  /**
   * Updates the default rule with {@code instance}. Every {@code gracePeriod} instances it
   * tries to expand the rule. On expansion, the expanded rule gets a new ID and is sent out,
   * and a new default rule, seeded from the learning node's "other branch" statistics, takes
   * its place.
   *
   * @param instance training instance
   */
  public void trainOnInstanceImpl(Instance instance) {
    defaultRule.updateStatistics(instance);
    if (defaultRule.getInstancesSeen() % this.gracePeriod == 0.0) {
      if (defaultRule.tryToExpand(this.splitConfidence, this.tieThreshold)) {
        RuleActiveRegressionNode learningNode = (RuleActiveRegressionNode) defaultRule.getLearningNode();
        // The replacement keeps the current default-rule ID; it is seeded from the
        // statistics of the other branch of the split.
        ActiveRule newDefaultRule = newRule(defaultRule.getRuleNumberID(), learningNode,
            learningNode.getStatisticsOtherBranchSplit());
        defaultRule.split();
        defaultRule.setRuleNumberID(++ruleNumberID);
        // send out the new rule
        sendAddRuleEvent(defaultRule.getRuleNumberID(), this.defaultRule);
        defaultRule = newDefaultRule;
      }
    }
  }

  /**
   * Creates a rule, optionally seeded from an existing learning node and/or statistics.
   *
   * @param ID rule ID
   * @param node node whose perceptron (copied) and, when {@code statistics} is null, whose
   *          mean seed the new rule; may be null
   * @param statistics if non-null, {@code statistics[1] / statistics[0]} seeds the target mean
   *          with weight {@code statistics[0]}
   * @return the new rule
   */
  private ActiveRule newRule(int ID, RuleActiveRegressionNode node, double[] statistics) {
    ActiveRule r = newRule(ID);

    if (node != null) {
      if (node.getPerceptron() != null) {
        r.getLearningNode().setPerceptron(new Perceptron(node.getPerceptron()));
        r.getLearningNode().getPerceptron().setLearningRatio(this.learningRatio);
      }
      // Same null guard on the target mean as the statistics branch below.
      if (statistics == null && r.getLearningNode().getTargetMean() != null) {
        if (node.getNodeStatistics().getValue(0) > 0) {
          double mean = node.getNodeStatistics().getValue(1) / node.getNodeStatistics().getValue(0);
          r.getLearningNode().getTargetMean().reset(mean, 1);
        }
      }
    }
    if (statistics != null && ((RuleActiveRegressionNode) r.getLearningNode()).getTargetMean() != null) {
      if (statistics[0] > 0) {
        double mean = statistics[1] / statistics[0];
        ((RuleActiveRegressionNode) r.getLearningNode()).getTargetMean().reset(mean, (long) statistics[0]);
      }
    }
    return r;
  }

  /** Creates an empty rule configured with this processor's options. */
  private ActiveRule newRule(int ID) {
    ActiveRule r = new ActiveRule.Builder().
        threshold(this.pageHinckleyThreshold).
        alpha(this.pageHinckleyAlpha).
        changeDetection(this.driftDetection).
        predictionFunction(this.predictionFunction).
        statistics(new double[3]).
        learningRatio(this.learningRatio).
        numericObserver(numericObserver).
        id(ID).build();
    return r;
  }

  @Override
  public void onCreate(int id) {
    this.processorId = id;
    this.statistics = new double[] { 0.0, 0, 0 };
    this.ruleNumberID = 0;
    this.defaultRule = newRule(++this.ruleNumberID);
  }

  /**
   * Clones the processor's options and output streams. Transient state (the default rule) is
   * rebuilt in {@link #onCreate(int)}.
   */
  @Override
  public Processor newProcessor(Processor p) {
    AMRDefaultRuleProcessor oldProcessor = (AMRDefaultRuleProcessor) p;
    Builder builder = new Builder(oldProcessor);
    AMRDefaultRuleProcessor newProcessor = builder.build();
    newProcessor.resultStream = oldProcessor.resultStream;
    newProcessor.ruleStream = oldProcessor.ruleStream;
    return newProcessor;
  }

  /*
   * Send events
   */
  private void sendAddRuleEvent(int ruleID, ActiveRule rule) {
    RuleContentEvent rce = new RuleContentEvent(ruleID, rule, false);
    this.ruleStream.put(rce);
  }

  /*
   * Output streams
   */
  public void setRuleStream(Stream ruleStream) {
    this.ruleStream = ruleStream;
  }

  public Stream getRuleStream() {
    return this.ruleStream;
  }

  public void setResultStream(Stream resultStream) {
    this.resultStream = resultStream;
  }

  public Stream getResultStream() {
    return this.resultStream;
  }

  /**
   * Fluent builder for {@link AMRDefaultRuleProcessor}.
   */
  public static class Builder {
    private int pageHinckleyThreshold;
    private double pageHinckleyAlpha;
    private boolean driftDetection;
    private int predictionFunction; // Adaptive=0 Perceptron=1 TargetMean=2
    private boolean constantLearningRatioDecay;
    private double learningRatio;
    private double splitConfidence;
    private double tieThreshold;
    private int gracePeriod;

    private FIMTDDNumericAttributeClassLimitObserver numericObserver;

    private Instances dataset;

    public Builder(Instances dataset) {
      this.dataset = dataset;
    }

    /** Copies every option from an existing processor. */
    public Builder(AMRDefaultRuleProcessor processor) {
      this.pageHinckleyThreshold = processor.pageHinckleyThreshold;
      this.pageHinckleyAlpha = processor.pageHinckleyAlpha;
      this.driftDetection = processor.driftDetection;
      this.predictionFunction = processor.predictionFunction;
      this.constantLearningRatioDecay = processor.constantLearningRatioDecay;
      this.learningRatio = processor.learningRatio;
      this.splitConfidence = processor.splitConfidence;
      this.tieThreshold = processor.tieThreshold;
      this.gracePeriod = processor.gracePeriod;

      this.numericObserver = processor.numericObserver;
    }

    public Builder threshold(int threshold) {
      this.pageHinckleyThreshold = threshold;
      return this;
    }

    public Builder alpha(double alpha) {
      this.pageHinckleyAlpha = alpha;
      return this;
    }

    public Builder changeDetection(boolean changeDetection) {
      this.driftDetection = changeDetection;
      return this;
    }

    public Builder predictionFunction(int predictionFunction) {
      this.predictionFunction = predictionFunction;
      return this;
    }

    public Builder constantLearningRatioDecay(boolean constantDecay) {
      this.constantLearningRatioDecay = constantDecay;
      return this;
    }

    public Builder learningRatio(double learningRatio) {
      this.learningRatio = learningRatio;
      return this;
    }

    public Builder splitConfidence(double splitConfidence) {
      this.splitConfidence = splitConfidence;
      return this;
    }

    public Builder tieThreshold(double tieThreshold) {
      this.tieThreshold = tieThreshold;
      return this;
    }

    public Builder gracePeriod(int gracePeriod) {
      this.gracePeriod = gracePeriod;
      return this;
    }

    public Builder numericObserver(FIMTDDNumericAttributeClassLimitObserver numericObserver) {
      this.numericObserver = numericObserver;
      return this;
    }

    public AMRDefaultRuleProcessor build() {
      return new AMRDefaultRuleProcessor(this);
    }
  }

}
http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/9b178f63/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/rules/distributed/AMRLearnerProcessor.java ---------------------------------------------------------------------- diff --git
a/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/rules/distributed/AMRLearnerProcessor.java b/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/rules/distributed/AMRLearnerProcessor.java new file mode 100644 index 0000000..a718945 --- /dev/null +++ b/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/rules/distributed/AMRLearnerProcessor.java @@ -0,0 +1,259 @@
package org.apache.samoa.learners.classifiers.rules.distributed;

/*
 * #%L
 * SAMOA
 * %%
 * Copyright (C) 2014 - 2015 Apache Software Foundation
 * %%
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * #L%
 */

import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;

import org.apache.samoa.core.ContentEvent;
import org.apache.samoa.core.Processor;
import org.apache.samoa.instances.Instance;
import org.apache.samoa.instances.Instances;
import org.apache.samoa.learners.classifiers.rules.common.ActiveRule;
import org.apache.samoa.learners.classifiers.rules.common.LearningRule;
import org.apache.samoa.learners.classifiers.rules.common.RuleActiveRegressionNode;
import org.apache.samoa.learners.classifiers.rules.common.RulePassiveRegressionNode;
import org.apache.samoa.learners.classifiers.rules.common.RuleSplitNode;
import org.apache.samoa.topology.Stream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Learner Processor (HAMR).
 *
 * <p>
 * Keeps a set of active rules. For each assigned instance it runs change detection on the
 * target rule and either drops the rule or trains it. Rule removals and newly found predicates
 * are sent on the output stream.
 *
 * @author Anh Thu Vu
 */
public class AMRLearnerProcessor implements Processor {

  private static final long serialVersionUID = -2302897295090248013L;

  private static final Logger logger =
      LoggerFactory.getLogger(AMRLearnerProcessor.class);

  private int processorId;

  private transient List<ActiveRule> ruleSet;

  private Stream outputStream;

  private double splitConfidence;
  private double tieThreshold;
  private int gracePeriod;

  private boolean noAnomalyDetection;
  private double multivariateAnomalyProbabilityThreshold;
  private double univariateAnomalyprobabilityThreshold;
  private int anomalyNumInstThreshold;

  /**
   * Creates the processor from the options collected by {@code builder}.
   *
   * @param builder source of all option values
   */
  public AMRLearnerProcessor(Builder builder) {
    this.splitConfidence = builder.splitConfidence;
    this.tieThreshold = builder.tieThreshold;
    this.gracePeriod = builder.gracePeriod;

    this.noAnomalyDetection = builder.noAnomalyDetection;
    this.multivariateAnomalyProbabilityThreshold = builder.multivariateAnomalyProbabilityThreshold;
    this.univariateAnomalyprobabilityThreshold = builder.univariateAnomalyprobabilityThreshold;
    this.anomalyNumInstThreshold = builder.anomalyNumInstThreshold;
  }

  /**
   * Trains on {@link AssignmentContentEvent}s and adds the rules carried by non-removing
   * {@link RuleContentEvent}s. Other events are ignored.
   *
   * @return always {@code false}
   */
  @Override
  public boolean process(ContentEvent event) {
    if (event instanceof AssignmentContentEvent) {
      AssignmentContentEvent attrContentEvent = (AssignmentContentEvent) event;
      trainRuleOnInstance(attrContentEvent.getRuleNumberID(), attrContentEvent.getInstance());
    }
    else if (event instanceof RuleContentEvent) {
      RuleContentEvent ruleContentEvent = (RuleContentEvent) event;
      if (!ruleContentEvent.isRemoving()) {
        addRule(ruleContentEvent.getRule());
      }
    }

    return false;
  }

  /**
   * Trains the first rule whose ID is {@code ruleID}, if it still covers {@code instance}. If
   * change detection fires, the rule is removed locally and a remove event is sent. Otherwise it
   * is updated, and every {@code gracePeriod} instances an expansion is attempted; a successful
   * expansion sends the new predicate.
   *
   * @param ruleID target rule
   * @param instance training instance
   */
  private void trainRuleOnInstance(int ruleID, Instance instance) {
    Iterator<ActiveRule> ruleIterator = this.ruleSet.iterator();
    while (ruleIterator.hasNext()) {
      ActiveRule rule = ruleIterator.next();
      if (rule.getRuleNumberID() != ruleID) {
        continue;
      }
      // Check (again) for coverage
      if (rule.isCovering(instance)) {
        double error = rule.computeError(instance); // Use adaptive mode error
        boolean changeDetected = ((RuleActiveRegressionNode) rule.getLearningNode()).updateChangeDetection(error);
        if (changeDetected) {
          ruleIterator.remove();
          this.sendRemoveRuleEvent(ruleID);
        } else {
          rule.updateStatistics(instance);
          if (rule.getInstancesSeen() % this.gracePeriod == 0.0) {
            if (rule.tryToExpand(this.splitConfidence, this.tieThreshold)) {
              rule.split();

              // expanded: update Aggregator with new/updated predicate
              this.sendPredicate(rule.getRuleNumberID(), rule.getLastUpdatedRuleSplitNode(),
                  (RuleActiveRegressionNode) rule.getLearningNode());
            }
          }
        }
      }
      return;
    }
  }

  /**
   * Checks {@code instance} against {@code rule}'s anomaly detector, once the rule has seen at
   * least {@code anomalyNumInstThreshold} instances. Currently not invoked within this class.
   *
   * @return {@code true} if anomaly detection is enabled and the rule flags the instance
   */
  private boolean isAnomaly(Instance instance, LearningRule rule) {
    // AMRUles is equipped with anomaly detection. If on, compute the anomaly
    // value.
    boolean isAnomaly = false;
    if (!this.noAnomalyDetection) {
      if (rule.getInstancesSeen() >= this.anomalyNumInstThreshold) {
        isAnomaly = rule.isAnomaly(instance,
            this.univariateAnomalyprobabilityThreshold,
            this.multivariateAnomalyProbabilityThreshold,
            this.anomalyNumInstThreshold);
      }
    }
    return isAnomaly;
  }

  private void sendRemoveRuleEvent(int ruleID) {
    RuleContentEvent rce = new RuleContentEvent(ruleID, null, true);
    this.outputStream.put(rce);
  }

  /** Sends the rule's new predicate together with a passive copy of its learning node. */
  private void sendPredicate(int ruleID, RuleSplitNode splitNode, RuleActiveRegressionNode learningNode) {
    this.outputStream.put(new PredicateContentEvent(ruleID, splitNode, new RulePassiveRegressionNode(learningNode)));
  }

  /*
   * Process control message (regarding adding or removing rules)
   */
  private boolean addRule(ActiveRule rule) {
    this.ruleSet.add(rule);
    return true;
  }

  @Override
  public void onCreate(int id) {
    this.processorId = id;
    this.ruleSet = new LinkedList<ActiveRule>();
  }

  /**
   * Clones the processor's options and output stream. The rule set is rebuilt empty in
   * {@link #onCreate(int)}.
   */
  @Override
  public Processor newProcessor(Processor p) {
    AMRLearnerProcessor oldProcessor = (AMRLearnerProcessor) p;
    AMRLearnerProcessor newProcessor =
        new AMRLearnerProcessor.Builder(oldProcessor).build();

    newProcessor.setOutputStream(oldProcessor.outputStream);
    return newProcessor;
  }

  /**
   * Fluent builder for {@link AMRLearnerProcessor}.
   */
  public static class Builder {
    private double splitConfidence;
    private double tieThreshold;
    private int gracePeriod;

    private boolean noAnomalyDetection;
    private double multivariateAnomalyProbabilityThreshold;
    private double univariateAnomalyprobabilityThreshold;
    private int anomalyNumInstThreshold;

    private Instances dataset;

    public Builder(Instances dataset) {
      this.dataset = dataset;
    }

    /** Copies every option from an existing processor. */
    public Builder(AMRLearnerProcessor processor) {
      this.splitConfidence = processor.splitConfidence;
      this.tieThreshold = processor.tieThreshold;
      this.gracePeriod = processor.gracePeriod;
      // Copy the anomaly-detection settings as well, so clones made by newProcessor() keep
      // them instead of falling back to field defaults.
      this.noAnomalyDetection = processor.noAnomalyDetection;
      this.multivariateAnomalyProbabilityThreshold = processor.multivariateAnomalyProbabilityThreshold;
      this.univariateAnomalyprobabilityThreshold = processor.univariateAnomalyprobabilityThreshold;
      this.anomalyNumInstThreshold = processor.anomalyNumInstThreshold;
    }

    public Builder splitConfidence(double splitConfidence) {
      this.splitConfidence = splitConfidence;
      return this;
    }

    public Builder tieThreshold(double tieThreshold) {
      this.tieThreshold = tieThreshold;
      return this;
    }

    public Builder gracePeriod(int gracePeriod) {
      this.gracePeriod = gracePeriod;
      return this;
    }

    public Builder noAnomalyDetection(boolean noAnomalyDetection) {
      this.noAnomalyDetection = noAnomalyDetection;
      return this;
    }

    public Builder multivariateAnomalyProbabilityThreshold(double mAnomalyThreshold) {
      this.multivariateAnomalyProbabilityThreshold = mAnomalyThreshold;
      return this;
    }

    public Builder univariateAnomalyProbabilityThreshold(double uAnomalyThreshold) {
      this.univariateAnomalyprobabilityThreshold = uAnomalyThreshold;
      return this;
    }

    public Builder anomalyNumberOfInstancesThreshold(int anomalyNumInstThreshold) {
      this.anomalyNumInstThreshold = anomalyNumInstThreshold;
      return this;
    }

    public AMRLearnerProcessor build() {
      return new AMRLearnerProcessor(this);
    }
  }

  /*
   * Output stream
   */
  public void setOutputStream(Stream stream) {
    this.outputStream = stream;
  }

  public Stream getOutputStream() {
    return this.outputStream;
  }
}
http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/9b178f63/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/rules/distributed/AMRRuleSetProcessor.java ---------------------------------------------------------------------- diff --git a/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/rules/distributed/AMRRuleSetProcessor.java b/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/rules/distributed/AMRRuleSetProcessor.java new file mode 100644 index 0000000..bf43ad0 --- /dev/null +++ b/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/rules/distributed/AMRRuleSetProcessor.java @@ -0,0 +1,373 @@ +package org.apache.samoa.learners.classifiers.rules.distributed; + +/* + * #%L + * SAMOA + * %% + * Copyright (C) 2014 - 2015 Apache
Software Foundation + * %% + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * #L% + */ + +import java.util.LinkedList; +import java.util.List; + +import org.apache.samoa.core.ContentEvent; +import org.apache.samoa.core.Processor; +import org.apache.samoa.instances.Instance; +import org.apache.samoa.instances.Instances; +import org.apache.samoa.learners.InstanceContentEvent; +import org.apache.samoa.learners.ResultContentEvent; +import org.apache.samoa.learners.classifiers.rules.common.ActiveRule; +import org.apache.samoa.learners.classifiers.rules.common.LearningRule; +import org.apache.samoa.learners.classifiers.rules.common.PassiveRule; +import org.apache.samoa.moa.classifiers.rules.core.voting.ErrorWeightedVote; +import org.apache.samoa.moa.classifiers.rules.core.voting.InverseErrorWeightedVote; +import org.apache.samoa.moa.classifiers.rules.core.voting.UniformWeightedVote; +import org.apache.samoa.topology.Stream; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Model Aggregator Processor (HAMR). 
+ * + * @author Anh Thu Vu + * + */ +public class AMRRuleSetProcessor implements Processor { + + /** + * + */ + private static final long serialVersionUID = -6544096255649379334L; + private static final Logger logger = LoggerFactory.getLogger(AMRRuleSetProcessor.class); + + private int processorId; + + // Rules & default rule + protected transient List<PassiveRule> ruleSet; + + // SAMOA Stream + private Stream statisticsStream; + private Stream resultStream; + private Stream defaultRuleStream; + + // Options + protected boolean noAnomalyDetection; + protected double multivariateAnomalyProbabilityThreshold; + protected double univariateAnomalyprobabilityThreshold; + protected int anomalyNumInstThreshold; + + protected boolean unorderedRules; + + protected int voteType; + + /* + * Constructor + */ + public AMRRuleSetProcessor(Builder builder) { + + this.noAnomalyDetection = builder.noAnomalyDetection; + this.multivariateAnomalyProbabilityThreshold = builder.multivariateAnomalyProbabilityThreshold; + this.univariateAnomalyprobabilityThreshold = builder.univariateAnomalyprobabilityThreshold; + this.anomalyNumInstThreshold = builder.anomalyNumInstThreshold; + this.unorderedRules = builder.unorderedRules; + + this.voteType = builder.voteType; + } + + /* + * (non-Javadoc) + * + * @see org.apache.samoa.core.Processor#process(org.apache.samoa.core. 
+   * ContentEvent)
+   */
+  @Override
+  public boolean process(ContentEvent event) {
+    // Dispatch on the concrete event type; any other event type is ignored.
+    if (event instanceof InstanceContentEvent) {
+      this.processInstanceEvent((InstanceContentEvent) event);
+    } else if (event instanceof PredicateContentEvent) {
+      PredicateContentEvent predicateEvent = (PredicateContentEvent) event;
+      // Without a split node the event only carries a new learning node for the rule.
+      if (predicateEvent.getRuleSplitNode() != null) {
+        this.updateRuleSplitNode(predicateEvent);
+      } else {
+        this.updateLearningNode(predicateEvent);
+      }
+    } else if (event instanceof RuleContentEvent) {
+      RuleContentEvent ruleEvent = (RuleContentEvent) event;
+      if (!ruleEvent.isRemoving()) {
+        addRule(ruleEvent.getRule());
+      } else {
+        this.removeRule(ruleEvent.getRuleNumberID());
+      }
+    }
+    return true;
+  }
+
+  /**
+   * Predicts and/or trains with one instance using the rules held by this processor.
+   *
+   * Every covering rule that is consulted adds its vote (weighted by its current error) to a
+   * combined prediction, and is trained with the instance unless it flags the instance as an
+   * anomaly; trained instances are also forwarded to the statistics stream. When
+   * {@code unorderedRules} is false, prediction stops after the first covering rule and training
+   * stops after the first covering rule that accepts the instance. Prediction or training that no
+   * rule handled is delegated to the default rule stream.
+   *
+   * @param instanceEvent
+   *          the incoming instance event; its testing/training flags are overwritten before it is
+   *          forwarded to the default rule stream
+   */
+  private void processInstanceEvent(InstanceContentEvent instanceEvent) {
+    final Instance inst = instanceEvent.getInstance();
+    boolean predicting = instanceEvent.isTesting();
+    boolean training = instanceEvent.isTraining();
+    boolean anyRuleCovers = false;
+    boolean anyRuleTrained = false;
+
+    ErrorWeightedVote combinedVote = newErrorWeightedVote();
+    for (PassiveRule rule : this.ruleSet) {
+      if (!(predicting || training)) {
+        break; // nothing left to do for this instance
+      }
+      if (!rule.isCovering(inst)) {
+        continue;
+      }
+      anyRuleCovers = true;
+
+      if (predicting) {
+        combinedVote.addVote(rule.getPrediction(inst), rule.getCurrentError());
+        // Ordered rules: only the first covering rule votes.
+        predicting = this.unorderedRules;
+      }
+
+      if (training && !isAnomaly(inst, rule)) {
+        anyRuleTrained = true;
+        rule.updateStatistics(inst);
+        // Forward the instance to the statistics PI responsible for this rule.
+        sendInstanceToRule(inst, rule.getRuleNumberID());
+        // Ordered rules: only the first accepting rule is trained.
+        training = this.unorderedRules;
+      }
+    }
+
+    if (anyRuleCovers) {
+      resultStream.put(newResultContentEvent(combinedVote.computeWeightedVote(), instanceEvent));
+    }
+
+    // Whatever the rule set did not handle is delegated to the default rule.
+    boolean delegatePrediction = instanceEvent.isTesting() && !anyRuleCovers;
+    boolean delegateTraining = instanceEvent.isTraining() && !anyRuleTrained;
+    if (delegatePrediction || delegateTraining) {
+      instanceEvent.setTesting(delegatePrediction);
+      instanceEvent.setTraining(delegateTraining);
+      this.defaultRuleStream.put(instanceEvent);
+    }
+  }
+
+  /**
+   * Wraps a prediction for an instance event into a result event tagged with this processor's id
+   * and the originating event's evaluation index.
+   *
+   * @param prediction
+   *          the predicted value(s)
+   * @param inEvent
+   *          the instance event the prediction belongs to
+   * @return the result event to put on the result stream
+   */
+  private ResultContentEvent newResultContentEvent(double[] prediction, InstanceContentEvent inEvent) {
+    ResultContentEvent result = new ResultContentEvent(
+        inEvent.getInstanceIndex(),
+        inEvent.getInstance(),
+        inEvent.getClassId(),
+        prediction,
+        inEvent.isLastEvent());
+    result.setClassifierIndex(this.processorId);
+    result.setEvaluationIndex(inEvent.getEvaluationIndex());
+    return result;
+  }
+
+  /**
+   * Creates the vote combiner selected by {@code voteType}: 1 selects a uniform vote, any other
+   * value selects an inverse-error weighted vote.
+   *
+   * @return a fresh, empty vote combiner
+   */
+  public ErrorWeightedVote newErrorWeightedVote() {
+    // TODO: do a reset instead of init a new object
+    if (voteType != 1) {
+      return new InverseErrorWeightedVote();
+    }
+    return new UniformWeightedVote();
+  }
+
+  /**
+   * Checks whether an instance is an anomaly for a given rule.
+   *
+   * @param instance
+   *          the instance to check
+   * @param rule
+   *          the rule whose statistics are used for the check
+   * @return true only when anomaly detection is enabled, the rule has seen at least
+   *         {@code anomalyNumInstThreshold} instances, and the rule reports the instance as
+   *         anomalous
+   */
+  private boolean isAnomaly(Instance instance, LearningRule rule) {
+    // AMRules is equipped with anomaly detection; it only runs when enabled and once the rule
+    // has seen enough instances.
+    if (this.noAnomalyDetection) {
+      return false;
+    }
+    if (!(rule.getInstancesSeen() >= this.anomalyNumInstThreshold)) {
+      return false;
+    }
+    return rule.isAnomaly(instance,
+        this.univariateAnomalyprobabilityThreshold,
+        this.multivariateAnomalyProbabilityThreshold,
+        this.anomalyNumInstThreshold);
+  }
+
+  /*
+   * Add predicate/RuleSplitNode for a rule: every rule with the event's ID receives the split node
+   * and the new learning node.
+   */
+  private void updateRuleSplitNode(PredicateContentEvent pce) {
+    final int targetId = pce.getRuleNumberID();
+    for (PassiveRule candidate : ruleSet) {
+      if (candidate.getRuleNumberID() != targetId) {
+        continue;
+      }
+      candidate.nodeListAdd(pce.getRuleSplitNode());
+      candidate.setLearningNode(pce.getLearningNode());
+    }
+  }
+
+  /*
+   * Replace the learning node of every rule with the event's ID.
+   */
+  private void updateLearningNode(PredicateContentEvent pce) {
+    final int targetId = pce.getRuleNumberID();
+    for (PassiveRule candidate : ruleSet) {
+      if (candidate.getRuleNumberID() == targetId) {
+        candidate.setLearningNode(pce.getLearningNode());
+      }
+    }
+  }
+
+  /*
+   * Add new rule (stored as a passive copy); always reports success.
+   */
+  private boolean addRule(ActiveRule rule) {
+    PassiveRule passiveCopy = new PassiveRule(rule);
+    this.ruleSet.add(passiveCopy);
+    return true;
+  }
+
+  /*
+   * Remove the first rule with the given ID, if any.
+   */
+  private void removeRule(int ruleID) {
+    for (PassiveRule candidate : ruleSet) {
+      if (candidate.getRuleNumberID() == ruleID) {
+        // Leaving the loop right away keeps the enhanced-for iterator from seeing the removal.
+        ruleSet.remove(candidate);
+        return;
+      }
+    }
+  }
+
+  @Override
+  public void onCreate(int id) {
+    this.processorId = id;
+    this.ruleSet = new LinkedList<PassiveRule>();
+  }
+
+  /*
+   * Clone processor: same options (via the Builder) and the same output streams.
+   */
+  @Override
+  public Processor newProcessor(Processor p) {
+    AMRRuleSetProcessor source = (AMRRuleSetProcessor) p;
+    AMRRuleSetProcessor copy = new Builder(source).build();
+    copy.resultStream = source.resultStream;
+    copy.statisticsStream = source.statisticsStream;
+    copy.defaultRuleStream = source.defaultRuleStream;
+    return copy;
+  }
+
+  /*
+   * Send an instance, tagged with the rule ID, to the statistics stream.
+   */
+  private void sendInstanceToRule(Instance instance, int ruleID) {
+    this.statisticsStream.put(new AssignmentContentEvent(ruleID, instance));
+  }
+
+  /*
+   * Output streams
+   */
+  public void setStatisticsStream(Stream statisticsStream) {
+    this.statisticsStream = statisticsStream;
+  }
+
+  public Stream getStatisticsStream() {
+    return this.statisticsStream;
+  }
+
+  public void setResultStream(Stream resultStream) {
+    this.resultStream = resultStream;
+  }
+
+  public Stream getResultStream() {
+    return this.resultStream;
+  }
+
+  public Stream getDefaultRuleStream() {
+    return this.defaultRuleStream;
+  }
+
+  public void setDefaultRuleStream(Stream defaultRuleStream) {
+    this.defaultRuleStream = defaultRuleStream;
+  }
+
+  /**
+   * Fluent builder for {@link AMRRuleSetProcessor}. It can start empty (from a dataset header) or
+   * from the options of an existing processor.
+   */
+  public static class Builder {
+    private boolean noAnomalyDetection;
+    private double multivariateAnomalyProbabilityThreshold;
+    private double univariateAnomalyprobabilityThreshold;
+    private int anomalyNumInstThreshold;
+
+    private boolean unorderedRules;
+
+    // private FIMTDDNumericAttributeClassLimitObserver numericObserver;
+    private int voteType;
+
+    // Stored but not read by this builder.
+    private Instances dataset;
+
+    public Builder(Instances dataset) {
+      this.dataset = dataset;
+    }
+
+    /** Copies the options (not the streams or rules) of an existing processor. */
+    public Builder(AMRRuleSetProcessor processor) {
+      this.noAnomalyDetection = processor.noAnomalyDetection;
+      this.multivariateAnomalyProbabilityThreshold = processor.multivariateAnomalyProbabilityThreshold;
+      this.univariateAnomalyprobabilityThreshold = processor.univariateAnomalyprobabilityThreshold;
+      this.anomalyNumInstThreshold = processor.anomalyNumInstThreshold;
+      this.unorderedRules = processor.unorderedRules;
+      this.voteType = processor.voteType;
+    }
+
+    public Builder noAnomalyDetection(boolean noAnomalyDetection) {
+      this.noAnomalyDetection = noAnomalyDetection;
+      return this;
+    }
+
+    public Builder multivariateAnomalyProbabilityThreshold(double mAnomalyThreshold) {
+      this.multivariateAnomalyProbabilityThreshold = mAnomalyThreshold;
+      return this;
+    }
+
+    public Builder univariateAnomalyProbabilityThreshold(double uAnomalyThreshold) {
+      this.univariateAnomalyprobabilityThreshold = uAnomalyThreshold;
+      return this;
+    }
+
+    public Builder anomalyNumberOfInstancesThreshold(int anomalyNumInstThreshold) {
+      this.anomalyNumInstThreshold = anomalyNumInstThreshold;
+      return this;
+    }
+
+    public Builder unorderedRules(boolean unorderedRules) {
+      this.unorderedRules = unorderedRules;
+      return this;
+    }
+
+    public Builder voteType(int voteType) {
+      this.voteType = voteType;
+      return this;
+    }
+
+    public AMRRuleSetProcessor build() {
+      return new AMRRuleSetProcessor(this);
+    }
+  }
+
+}
http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/9b178f63/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/rules/distributed/AMRulesAggregatorProcessor.java
----------------------------------------------------------------------
diff --git a/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/rules/distributed/AMRulesAggregatorProcessor.java b/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/rules/distributed/AMRulesAggregatorProcessor.java
new file mode 100644
index 0000000..2131db4
--- /dev/null
+++ b/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/rules/distributed/AMRulesAggregatorProcessor.java
@@ -0,0 +1,530 @@
+package org.apache.samoa.learners.classifiers.rules.distributed;
+
+/*
+ * #%L
+ * SAMOA
+ * %%
+ * Copyright (C) 2014 - 2015 Apache Software Foundation
+ * %%
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * #L%
+ */
+
+import java.util.Iterator;
+import java.util.LinkedList;
+import java.util.List;
+
+import org.apache.samoa.core.ContentEvent;
+import org.apache.samoa.core.Processor;
+import org.apache.samoa.instances.Instance;
+import org.apache.samoa.instances.Instances;
+import org.apache.samoa.learners.InstanceContentEvent;
+import org.apache.samoa.learners.ResultContentEvent;
+import org.apache.samoa.learners.classifiers.rules.common.ActiveRule;
+import org.apache.samoa.learners.classifiers.rules.common.LearningRule;
+import org.apache.samoa.learners.classifiers.rules.common.PassiveRule;
+import org.apache.samoa.learners.classifiers.rules.common.Perceptron;
+import org.apache.samoa.learners.classifiers.rules.common.RuleActiveRegressionNode;
+import org.apache.samoa.moa.classifiers.rules.core.attributeclassobservers.FIMTDDNumericAttributeClassLimitObserver;
+import org.apache.samoa.moa.classifiers.rules.core.voting.ErrorWeightedVote;
+import org.apache.samoa.moa.classifiers.rules.core.voting.InverseErrorWeightedVote;
+import org.apache.samoa.moa.classifiers.rules.core.voting.UniformWeightedVote;
+import org.apache.samoa.topology.Stream;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Model Aggregator Processor (VAMR).
+ *
+ * Holds passive copies of the learned rules plus an active default rule. It predicts with the
+ * covering rules (or the default rule), forwards training instances to the statistics PIs, and
+ * grows new rules out of the default rule.
+ *
+ * @author Anh Thu Vu
+ *
+ */
+public class AMRulesAggregatorProcessor implements Processor {
+
+  private static final long serialVersionUID = 6303385725332704251L;
+
+  private static final Logger logger =
+      LoggerFactory.getLogger(AMRulesAggregatorProcessor.class);
+
+  private int processorId;
+
+  // Rule set (passive copies) and the default rule that new rules are grown from
+  protected transient List<PassiveRule> ruleSet;
+  protected transient ActiveRule defaultRule;
+  protected transient int ruleNumberID;
+  protected transient double[] statistics;
+
+  // Output streams
+  private Stream statisticsStream;
+  private Stream resultStream;
+
+  // Rule learning options
+  protected int pageHinckleyThreshold;
+  protected double pageHinckleyAlpha;
+  protected boolean driftDetection;
+  protected int predictionFunction; // Adaptive=0 Perceptron=1 TargetMean=2
+  protected boolean constantLearningRatioDecay;
+  protected double learningRatio;
+
+  protected double splitConfidence;
+  protected double tieThreshold;
+  protected int gracePeriod;
+
+  // Anomaly detection options
+  protected boolean noAnomalyDetection;
+  protected double multivariateAnomalyProbabilityThreshold;
+  protected double univariateAnomalyprobabilityThreshold;
+  protected int anomalyNumInstThreshold;
+
+  protected boolean unorderedRules;
+
+  protected FIMTDDNumericAttributeClassLimitObserver numericObserver;
+  protected int voteType;
+
+  /**
+   * Creates a processor with the options collected by the builder. Rules and the default rule are
+   * created later, in {@link #onCreate(int)}.
+   *
+   * @param builder
+   *          source of all option values
+   */
+  public AMRulesAggregatorProcessor(Builder builder) {
+    // rule learning
+    this.pageHinckleyThreshold = builder.pageHinckleyThreshold;
+    this.pageHinckleyAlpha = builder.pageHinckleyAlpha;
+    this.driftDetection = builder.driftDetection;
+    this.predictionFunction = builder.predictionFunction;
+    this.constantLearningRatioDecay = builder.constantLearningRatioDecay;
+    this.learningRatio = builder.learningRatio;
+    this.splitConfidence = builder.splitConfidence;
+    this.tieThreshold = builder.tieThreshold;
+    this.gracePeriod = builder.gracePeriod;
+    this.numericObserver = builder.numericObserver;
+
+    // anomaly detection
+    this.noAnomalyDetection = builder.noAnomalyDetection;
+    this.multivariateAnomalyProbabilityThreshold = builder.multivariateAnomalyProbabilityThreshold;
+    this.univariateAnomalyprobabilityThreshold = builder.univariateAnomalyprobabilityThreshold;
+    this.anomalyNumInstThreshold = builder.anomalyNumInstThreshold;
+
+    // prediction
+    this.unorderedRules = builder.unorderedRules;
+    this.voteType = builder.voteType;
+  }
+
+  /*
+   * Process: instances are predicted/trained, predicates update rules, removal events drop rules.
+   * Every event is reported as processed.
+   */
+  @Override
+  public boolean process(ContentEvent event) {
+    if (event instanceof InstanceContentEvent) {
+      this.processInstanceEvent((InstanceContentEvent) event);
+    } else if (event instanceof PredicateContentEvent) {
+      this.updateRuleSplitNode((PredicateContentEvent) event);
+    } else if (event instanceof RuleContentEvent && ((RuleContentEvent) event).isRemoving()) {
+      this.removeRule(((RuleContentEvent) event).getRuleNumberID());
+    }
+    return true;
+  }
+
+  /**
+   * Predicts and trains with one instance in a single pass over the rules, so coverage is only
+   * checked once per rule.
+   *
+   * Covering rules add their votes to a combined prediction and, unless they flag the instance as
+   * an anomaly, are trained with it (the instance is also forwarded to the statistics PIs). When
+   * {@code unorderedRules} is false only the first covering rule votes and only the first
+   * accepting rule trains. If no rule covers a test instance, the default rule predicts; if no
+   * rule trained on a training instance, the default rule is trained with it.
+   *
+   * @param instanceEvent
+   *          the incoming instance event
+   */
+  private void processInstanceEvent(InstanceContentEvent instanceEvent) {
+    final Instance inst = instanceEvent.getInstance();
+    boolean predicting = instanceEvent.isTesting();
+    boolean training = instanceEvent.isTraining();
+    boolean anyRuleCovers = false;
+    boolean anyRuleTrained = false;
+
+    ErrorWeightedVote combinedVote = newErrorWeightedVote();
+    for (Iterator<PassiveRule> it = this.ruleSet.iterator(); it.hasNext() && (predicting || training);) {
+      PassiveRule rule = it.next();
+      if (!rule.isCovering(inst)) {
+        continue;
+      }
+      anyRuleCovers = true;
+
+      if (predicting) {
+        combinedVote.addVote(rule.getPrediction(inst), rule.getCurrentError());
+        predicting = this.unorderedRules;
+      }
+
+      if (training && !isAnomaly(inst, rule)) {
+        anyRuleTrained = true;
+        rule.updateStatistics(inst);
+        // Forward the instance to the statistics PI responsible for this rule.
+        sendInstanceToRule(inst, rule.getRuleNumberID());
+        training = this.unorderedRules;
+      }
+    }
+
+    if (anyRuleCovers) {
+      resultStream.put(newResultContentEvent(combinedVote.computeWeightedVote(), instanceEvent));
+    } else if (instanceEvent.isTesting()) {
+      // No rule covers the instance: fall back to the default rule's prediction.
+      resultStream.put(newResultContentEvent(defaultRule.getPrediction(inst), instanceEvent));
+    }
+
+    if (!anyRuleTrained && instanceEvent.isTraining()) {
+      trainDefaultRule(inst);
+    }
+  }
+
+  /**
+   * Trains the default rule with an instance no other rule accepted. Every {@code gracePeriod}
+   * instances it tries to expand; on success the expanded rule gets a fresh ID, is added to the
+   * rule set, is announced to the statistics PIs, and a new default rule takes its place.
+   *
+   * @param instance
+   *          the training instance
+   */
+  private void trainDefaultRule(Instance instance) {
+    defaultRule.updateStatistics(instance);
+    if (defaultRule.getInstancesSeen() % this.gracePeriod != 0.0) {
+      return;
+    }
+    if (!defaultRule.tryToExpand(this.splitConfidence, this.tieThreshold)) {
+      return;
+    }
+    // The next default rule is created from the current learning node (other-branch statistics)
+    // before split() is called on the expanding rule.
+    RuleActiveRegressionNode expandingNode = (RuleActiveRegressionNode) defaultRule.getLearningNode();
+    ActiveRule newDefaultRule = newRule(defaultRule.getRuleNumberID(), expandingNode,
+        expandingNode.getStatisticsOtherBranchSplit());
+    defaultRule.split();
+    defaultRule.setRuleNumberID(++ruleNumberID);
+    this.ruleSet.add(new PassiveRule(this.defaultRule));
+    // send to statistics PI
+    sendAddRuleEvent(defaultRule.getRuleNumberID(), this.defaultRule);
+    defaultRule = newDefaultRule;
+  }
+
+  /**
+   * Helper method to generate new ResultContentEvent based on an instance and its prediction result.
+   *
+   * @param prediction
+   *          The predicted value(s) for the instance.
+   * @param inEvent
+   *          The associated instance content event
+   * @return ResultContentEvent to be sent into Evaluator PI or other destination PI.
+   */
+  private ResultContentEvent newResultContentEvent(double[] prediction, InstanceContentEvent inEvent) {
+    ResultContentEvent result = new ResultContentEvent(
+        inEvent.getInstanceIndex(),
+        inEvent.getInstance(),
+        inEvent.getClassId(),
+        prediction,
+        inEvent.isLastEvent());
+    result.setClassifierIndex(this.processorId);
+    result.setEvaluationIndex(inEvent.getEvaluationIndex());
+    return result;
+  }
+
+  /**
+   * Creates the vote combiner selected by {@code voteType}: 1 selects a uniform vote, any other
+   * value selects an inverse-error weighted vote.
+   *
+   * @return a fresh, empty vote combiner
+   */
+  public ErrorWeightedVote newErrorWeightedVote() {
+    if (voteType != 1) {
+      return new InverseErrorWeightedVote();
+    }
+    return new UniformWeightedVote();
+  }
+
+  /**
+   * Checks whether an instance is an anomaly for a given rule.
+   *
+   * @param instance
+   *          the instance to check
+   * @param rule
+   *          the rule whose statistics are used for the check
+   * @return true only when anomaly detection is enabled, the rule has seen at least
+   *         {@code anomalyNumInstThreshold} instances, and the rule reports the instance as
+   *         anomalous
+   */
+  private boolean isAnomaly(Instance instance, LearningRule rule) {
+    if (this.noAnomalyDetection) {
+      return false;
+    }
+    if (!(rule.getInstancesSeen() >= this.anomalyNumInstThreshold)) {
+      return false;
+    }
+    return rule.isAnomaly(instance,
+        this.univariateAnomalyprobabilityThreshold,
+        this.multivariateAnomalyProbabilityThreshold,
+        this.anomalyNumInstThreshold);
+  }
+
+  /**
+   * Creates a rule with the given ID, warm-started from a parent learning node and/or statistics.
+   *
+   * If {@code node} has a perceptron, it is copied and given this processor's learning ratio. The
+   * target mean is reset from {@code statistics} when given (mean
+   * {@code statistics[1] / statistics[0]}, weight {@code (long) statistics[0]}), otherwise from the
+   * node's own statistics (mean {@code value(1) / value(0)}, weight 1). Nothing is reset when the
+   * corresponding count is not positive.
+   *
+   * @param ID
+   *          the rule ID
+   * @param node
+   *          parent learning node, may be null
+   * @param statistics
+   *          explicit statistics for the target mean, may be null
+   * @return the new rule
+   */
+  private ActiveRule newRule(int ID, RuleActiveRegressionNode node, double[] statistics) {
+    ActiveRule created = newRule(ID);
+
+    if (node != null) {
+      if (node.getPerceptron() != null) {
+        created.getLearningNode().setPerceptron(new Perceptron(node.getPerceptron()));
+        created.getLearningNode().getPerceptron().setLearningRatio(this.learningRatio);
+      }
+      if (statistics == null && node.getNodeStatistics().getValue(0) > 0) {
+        double seedMean = node.getNodeStatistics().getValue(1) / node.getNodeStatistics().getValue(0);
+        created.getLearningNode().getTargetMean().reset(seedMean, 1);
+      }
+    }
+
+    if (statistics != null
+        && ((RuleActiveRegressionNode) created.getLearningNode()).getTargetMean() != null
+        && statistics[0] > 0) {
+      double seedMean = statistics[1] / statistics[0];
+      ((RuleActiveRegressionNode) created.getLearningNode()).getTargetMean().reset(seedMean,
+          (long) statistics[0]);
+    }
+    return created;
+  }
+
+  /**
+   * Creates an empty rule with the given ID, configured from this processor's options.
+   *
+   * @param ID
+   *          the rule ID
+   * @return the new rule
+   */
+  private ActiveRule newRule(int ID) {
+    return new ActiveRule.Builder()
+        .threshold(this.pageHinckleyThreshold)
+        .alpha(this.pageHinckleyAlpha)
+        .changeDetection(this.driftDetection)
+        .predictionFunction(this.predictionFunction)
+        .statistics(new double[3])
+        .learningRatio(this.learningRatio)
+        .numericObserver(numericObserver)
+        .id(ID)
+        .build();
+  }
+
+  /*
+   * Add predicate/RuleSplitNode and/or a new learning node to every rule with the event's ID;
+   * null parts of the event are ignored.
+   */
+  private void updateRuleSplitNode(PredicateContentEvent pce) {
+    final int targetId = pce.getRuleNumberID();
+    for (PassiveRule candidate : ruleSet) {
+      if (candidate.getRuleNumberID() != targetId) {
+        continue;
+      }
+      if (pce.getRuleSplitNode() != null) {
+        candidate.nodeListAdd(pce.getRuleSplitNode());
+      }
+      if (pce.getLearningNode() != null) {
+        candidate.setLearningNode(pce.getLearningNode());
+      }
+    }
+  }
+
+  /*
+   * Remove the first rule with the given ID, if any.
+   */
+  private void removeRule(int ruleID) {
+    for (PassiveRule candidate : ruleSet) {
+      if (candidate.getRuleNumberID() == ruleID) {
+        // Leaving the loop right away keeps the enhanced-for iterator from seeing the removal.
+        ruleSet.remove(candidate);
+        return;
+      }
+    }
+  }
+
+  @Override
+  public void onCreate(int id) {
+    this.processorId = id;
+    this.statistics = new double[3]; // all zeros
+    this.ruleNumberID = 1; // the default rule takes ID 1
+    this.defaultRule = newRule(this.ruleNumberID);
+    this.ruleSet = new LinkedList<PassiveRule>();
+  }
+
+  /*
+   * Clone processor: same options (via the Builder) and the same output streams.
+   */
+  @Override
+  public Processor newProcessor(Processor p) {
+    AMRulesAggregatorProcessor source = (AMRulesAggregatorProcessor) p;
+    AMRulesAggregatorProcessor copy = new Builder(source).build();
+    copy.resultStream = source.resultStream;
+    copy.statisticsStream = source.statisticsStream;
+    return copy;
+  }
+
+  /*
+   * Send events to the statistics stream
+   */
+  private void sendInstanceToRule(Instance instance, int ruleID) {
+    this.statisticsStream.put(new AssignmentContentEvent(ruleID, instance));
+  }
+
+  private void sendAddRuleEvent(int ruleID, ActiveRule rule) {
+    this.statisticsStream.put(new RuleContentEvent(ruleID, rule, false));
+  }
+
+  /*
+   * Output streams
+   */
+  public void setStatisticsStream(Stream statisticsStream) {
+    this.statisticsStream = statisticsStream;
+  }
+
+  public Stream getStatisticsStream() {
+    return this.statisticsStream;
+  }
+
+  public void setResultStream(Stream resultStream) {
+    this.resultStream = resultStream;
+  }
+
+  public Stream getResultStream() {
+    return this.resultStream;
+  }
+
+  /*
+   * Others
+   */
+  public boolean isRandomizable() {
+    return true;
+  }
+
+  /**
+   * Fluent builder for {@link AMRulesAggregatorProcessor}. It can start empty (from a dataset
+   * header) or from the options of an existing processor.
+   */
+  public static class Builder {
+    private int pageHinckleyThreshold;
+    private double pageHinckleyAlpha;
+    private boolean driftDetection;
+    private int predictionFunction; // Adaptive=0 Perceptron=1 TargetMean=2
+    private boolean constantLearningRatioDecay;
+    private double learningRatio;
+    private double splitConfidence;
+    private double tieThreshold;
+    private int gracePeriod;
+
+    private boolean noAnomalyDetection;
+    private double multivariateAnomalyProbabilityThreshold;
+    private double univariateAnomalyprobabilityThreshold;
+    private int anomalyNumInstThreshold;
+
+    private boolean unorderedRules;
+
+    private FIMTDDNumericAttributeClassLimitObserver numericObserver;
+    private int voteType;
+
+    // Stored but not read by this builder.
+    private Instances dataset;
+
+    public Builder(Instances dataset) {
+      this.dataset = dataset;
+    }
+
+    /** Copies the options (not the streams or rules) of an existing processor. */
+    public Builder(AMRulesAggregatorProcessor processor) {
+      this.pageHinckleyThreshold = processor.pageHinckleyThreshold;
+      this.pageHinckleyAlpha = processor.pageHinckleyAlpha;
+      this.driftDetection = processor.driftDetection;
+      this.predictionFunction = processor.predictionFunction;
+      this.constantLearningRatioDecay = processor.constantLearningRatioDecay;
+      this.learningRatio = processor.learningRatio;
+      this.splitConfidence = processor.splitConfidence;
+      this.tieThreshold = processor.tieThreshold;
+      this.gracePeriod = processor.gracePeriod;
+
+      this.noAnomalyDetection = processor.noAnomalyDetection;
+      this.multivariateAnomalyProbabilityThreshold = processor.multivariateAnomalyProbabilityThreshold;
+      this.univariateAnomalyprobabilityThreshold = processor.univariateAnomalyprobabilityThreshold;
+      this.anomalyNumInstThreshold = processor.anomalyNumInstThreshold;
+      this.unorderedRules = processor.unorderedRules;
+
+      this.numericObserver = processor.numericObserver;
+      this.voteType = processor.voteType;
+    }
+
+    public Builder threshold(int threshold) {
+      this.pageHinckleyThreshold = threshold;
+      return this;
+    }
+
+    public Builder alpha(double alpha) {
+      this.pageHinckleyAlpha = alpha;
+      return this;
+    }
+
+    public Builder changeDetection(boolean changeDetection) {
+      this.driftDetection = changeDetection;
+      return this;
+    }
+
+    public Builder predictionFunction(int predictionFunction) {
+      this.predictionFunction = predictionFunction;
+      return this;
+    }
+
+    public Builder constantLearningRatioDecay(boolean constantDecay) {
+      this.constantLearningRatioDecay = constantDecay;
+      return this;
+    }
+
+    public Builder learningRatio(double learningRatio) {
+      this.learningRatio = learningRatio;
+      return this;
+    }
+
+    public Builder splitConfidence(double splitConfidence) {
+      this.splitConfidence = splitConfidence;
+      return this;
+    }
+
+    public Builder tieThreshold(double tieThreshold) {
+      this.tieThreshold = tieThreshold;
+      return this;
+    }
+
+    public Builder gracePeriod(int gracePeriod) {
+      this.gracePeriod = gracePeriod;
+      return this;
+    }
+
+    public Builder noAnomalyDetection(boolean noAnomalyDetection) {
+      this.noAnomalyDetection = noAnomalyDetection;
+      return this;
+    }
+
+    public Builder multivariateAnomalyProbabilityThreshold(double mAnomalyThreshold) {
+      this.multivariateAnomalyProbabilityThreshold = mAnomalyThreshold;
+      return this;
+    }
+
+    public Builder univariateAnomalyProbabilityThreshold(double uAnomalyThreshold) {
+      this.univariateAnomalyprobabilityThreshold = uAnomalyThreshold;
+      return this;
+    }
+
+    public Builder anomalyNumberOfInstancesThreshold(int anomalyNumInstThreshold) {
+      this.anomalyNumInstThreshold = anomalyNumInstThreshold;
+      return this;
+    }
+
+    public Builder unorderedRules(boolean unorderedRules) {
+      this.unorderedRules = unorderedRules;
+      return this;
+    }
+
+    public Builder numericObserver(FIMTDDNumericAttributeClassLimitObserver numericObserver) {
+      this.numericObserver = numericObserver;
+      return this;
+    }
+
+    public Builder voteType(int voteType) {
+      this.voteType = voteType;
+      return this;
+    }
+
+    public AMRulesAggregatorProcessor build() {
+      return new AMRulesAggregatorProcessor(this);
+    }
+  }
+
+}
http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/9b178f63/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/rules/distributed/AMRulesStatisticsProcessor.java
----------------------------------------------------------------------
diff --git a/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/rules/distributed/AMRulesStatisticsProcessor.java b/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/rules/distributed/AMRulesStatisticsProcessor.java
new file mode 100644
index 0000000..86fae3c
--- /dev/null
+++ b/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/rules/distributed/AMRulesStatisticsProcessor.java
@@ -0,0 +1,219 @@
+package org.apache.samoa.learners.classifiers.rules.distributed;
+
+/*
+ * #%L
+ * SAMOA
+ * %%
+ * Copyright (C) 2014 - 2015 Apache Software Foundation
+ * %%
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * #L%
+ */
+
+import java.util.Iterator;
+import java.util.LinkedList;
+import java.util.List;
+
+import org.apache.samoa.core.ContentEvent;
+import org.apache.samoa.core.Processor;
+import org.apache.samoa.instances.Instance;
+import org.apache.samoa.instances.Instances;
+import org.apache.samoa.learners.classifiers.rules.common.ActiveRule;
+import org.apache.samoa.learners.classifiers.rules.common.RuleActiveRegressionNode;
+import org.apache.samoa.learners.classifiers.rules.common.RulePassiveRegressionNode;
+import org.apache.samoa.learners.classifiers.rules.common.RuleSplitNode;
+import org.apache.samoa.topology.Stream;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Learner Processor (VAMR).
+ *
+ * Keeps the active copies of the rules assigned to it, trains them with the instances forwarded
+ * by the Model Aggregator, and reports back rule removals (on detected change) and new predicates
+ * (on expansion) through its output stream.
+ *
+ * @author Anh Thu Vu
+ *
+ */
+public class AMRulesStatisticsProcessor implements Processor {
+
+  private static final long serialVersionUID = 5268933189695395573L;
+
+  private static final Logger logger =
+      LoggerFactory.getLogger(AMRulesStatisticsProcessor.class);
+
+  private int processorId;
+
+  private transient List<ActiveRule> ruleSet;
+
+  private Stream outputStream;
+
+  private double splitConfidence;
+  private double tieThreshold;
+  private int gracePeriod;
+
+  private int frequency;
+
+  public AMRulesStatisticsProcessor(Builder builder) {
+    this.splitConfidence = builder.splitConfidence;
+    this.tieThreshold = builder.tieThreshold;
+    this.gracePeriod = builder.gracePeriod;
+    this.frequency = builder.frequency;
+  }
+
+  /**
+   * Handles an incoming event: assignment events train the addressed rule, non-removing rule
+   * events add a rule. Other events are ignored.
+   *
+   * @param event
+   *          the event to process
+   * @return always {@code true}, signalling successful processing like the other rule processors
+   *         (the original returned {@code false} for every event)
+   */
+  @Override
+  public boolean process(ContentEvent event) {
+    if (event instanceof AssignmentContentEvent) {
+      AssignmentContentEvent assignment = (AssignmentContentEvent) event;
+      trainRuleOnInstance(assignment.getRuleNumberID(), assignment.getInstance());
+    } else if (event instanceof RuleContentEvent) {
+      RuleContentEvent ruleContentEvent = (RuleContentEvent) event;
+      if (!ruleContentEvent.isRemoving()) {
+        addRule(ruleContentEvent.getRule());
+      }
+    }
+    return true;
+  }
+
+  /**
+   * Trains the first rule with the given ID on an instance.
+   *
+   * If the rule still covers the instance, its change detector is fed with the rule's error. On a
+   * detected change the rule is dropped and a removal event is sent; otherwise the rule is
+   * updated and, every {@code gracePeriod} instances, may expand, in which case the new predicate
+   * and learning node are sent to the Aggregator. Nothing happens if no rule has the ID.
+   *
+   * @param ruleID
+   *          ID of the rule to train
+   * @param instance
+   *          the training instance
+   */
+  private void trainRuleOnInstance(int ruleID, Instance instance) {
+    Iterator<ActiveRule> ruleIterator = this.ruleSet.iterator();
+    while (ruleIterator.hasNext()) {
+      ActiveRule rule = ruleIterator.next();
+      if (rule.getRuleNumberID() != ruleID) {
+        continue;
+      }
+      // Check (again) for coverage.
+      // Skip anomaly check as Aggregator's perceptron should be well-updated.
+      if (rule.isCovering(instance)) {
+        double error = rule.computeError(instance); // Use adaptive mode error
+        boolean changeDetected = ((RuleActiveRegressionNode) rule.getLearningNode()).updateChangeDetection(error);
+        if (changeDetected) {
+          ruleIterator.remove();
+          this.sendRemoveRuleEvent(ruleID);
+        } else {
+          rule.updateStatistics(instance);
+          if (rule.getInstancesSeen() % this.gracePeriod == 0.0
+              && rule.tryToExpand(this.splitConfidence, this.tieThreshold)) {
+            rule.split();
+            // expanded: update Aggregator with new/updated predicate
+            this.sendPredicate(rule.getRuleNumberID(), rule.getLastUpdatedRuleSplitNode(),
+                (RuleActiveRegressionNode) rule.getLearningNode());
+          }
+        }
+      }
+      // Only the first rule with a matching ID is considered.
+      return;
+    }
+  }
+
+  private void sendRemoveRuleEvent(int ruleID) {
+    RuleContentEvent rce = new RuleContentEvent(ruleID, null, true);
+    this.outputStream.put(rce);
+  }
+
+  private void sendPredicate(int ruleID, RuleSplitNode splitNode, RuleActiveRegressionNode learningNode) {
+    this.outputStream.put(new PredicateContentEvent(ruleID, splitNode, new RulePassiveRegressionNode(learningNode)));
+  }
+
+  /*
+   * Process control message (regarding adding or removing rules); always reports success.
+   */
+  private boolean addRule(ActiveRule rule) {
+    this.ruleSet.add(rule);
+    return true;
+  }
+
+  @Override
+  public void onCreate(int id) {
+    this.processorId = id;
+    this.ruleSet = new LinkedList<ActiveRule>();
+  }
+
+  /*
+   * Clone processor: same options (via the Builder) and the same output stream.
+   */
+  @Override
+  public Processor newProcessor(Processor p) {
+    AMRulesStatisticsProcessor oldProcessor = (AMRulesStatisticsProcessor) p;
+    AMRulesStatisticsProcessor newProcessor =
+        new AMRulesStatisticsProcessor.Builder(oldProcessor).build();
+
+    newProcessor.setOutputStream(oldProcessor.outputStream);
+    return newProcessor;
+  }
+
+  /**
+   * Fluent builder for {@link AMRulesStatisticsProcessor}.
+   */
+  public static class Builder {
+    private double splitConfidence;
+    private double tieThreshold;
+    private int gracePeriod;
+
+    private int frequency;
+
+    // Stored but not read by this builder.
+    private Instances dataset;
+
+    public Builder(Instances dataset) {
+      this.dataset = dataset;
+    }
+
+    /** Copies the options (not the stream or rules) of an existing processor. */
+    public Builder(AMRulesStatisticsProcessor processor) {
+      this.splitConfidence = processor.splitConfidence;
+      this.tieThreshold = processor.tieThreshold;
+      this.gracePeriod = processor.gracePeriod;
+      this.frequency = processor.frequency;
+    }
+
+    public Builder splitConfidence(double splitConfidence) {
+      this.splitConfidence = splitConfidence;
+      return this;
+    }
+
+    public Builder tieThreshold(double tieThreshold) {
+      this.tieThreshold = tieThreshold;
+      return this;
+    }
+
+    public Builder gracePeriod(int gracePeriod) {
+      this.gracePeriod = gracePeriod;
+      return this;
+    }
+
+    public Builder frequency(int frequency) {
+      this.frequency = frequency;
+      return this;
+    }
+
+    public AMRulesStatisticsProcessor build() {
+      return new AMRulesStatisticsProcessor(this);
+    }
+  }
+
+  /*
+   * Output stream
+   */
+  public void setOutputStream(Stream stream) {
+    this.outputStream = stream;
+  }
+
+  public Stream getOutputStream() {
+    return this.outputStream;
+  }
+
+}
http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/9b178f63/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/rules/distributed/AssignmentContentEvent.java
----------------------------------------------------------------------
diff --git a/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/rules/distributed/AssignmentContentEvent.java b/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/rules/distributed/AssignmentContentEvent.java
new file mode 100644
index 0000000..0c603f8
--- /dev/null
+++ b/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/rules/distributed/AssignmentContentEvent.java
@@ -0,0 +1,74 @@
+package org.apache.samoa.learners.classifiers.rules.distributed;
+
+/*
+ * #%L
+ * SAMOA
+ * %%
+ * Copyright (C) 2014 - 2015 Apache Software Foundation
+ * %%
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * #L%
+ */
+
+import org.apache.samoa.core.ContentEvent;
+import org.apache.samoa.instances.Instance;
+
+/**
+ * Forwarded instances from Model Aggregator to Learners/Default Rule Learner.
+ *
+ * The event key is the rule ID.
+ *
+ * @author Anh Thu Vu
+ *
+ */
+public class AssignmentContentEvent implements ContentEvent {
+
+  private static final long serialVersionUID = 1031695762172836629L;
+
+  private int ruleNumberID;
+  private Instance instance;
+
+  // No-arg constructor; presumably needed by the serialization framework — confirm.
+  public AssignmentContentEvent() {
+    this(0, null);
+  }
+
+  public AssignmentContentEvent(int ruleID, Instance instance) {
+    this.ruleNumberID = ruleID;
+    this.instance = instance;
+  }
+
+  @Override
+  public String getKey() {
+    return Integer.toString(this.ruleNumberID);
+  }
+
+  @Override
+  public void setKey(String key) {
+    // do nothing: the key is derived from the rule ID
+  }
+
+  @Override
+  public boolean isLastEvent() {
+    return false;
+  }
+
+  public Instance getInstance() {
+    return this.instance;
+  }
+
+  public int getRuleNumberID() {
+    return this.ruleNumberID;
+  }
+
+}
http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/9b178f63/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/rules/distributed/PredicateContentEvent.java
----------------------------------------------------------------------
diff --git a/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/rules/distributed/PredicateContentEvent.java b/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/rules/distributed/PredicateContentEvent.java
new file mode 100644
index 0000000..00d7eb8
--- /dev/null
+++ b/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/rules/distributed/PredicateContentEvent.java
@@ -0,0 +1,84 @@
+package org.apache.samoa.learners.classifiers.rules.distributed;
+
+/*
+ * #%L
+ * SAMOA
+ * %%
+ * Copyright (C) 2014 - 2015 Apache Software Foundation
+ * %%
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * #L%
+ */
+
+import org.apache.samoa.core.ContentEvent;
+import org.apache.samoa.learners.classifiers.rules.common.RulePassiveRegressionNode;
+import org.apache.samoa.learners.classifiers.rules.common.RuleSplitNode;
+
+/**
+ * New features (of newly expanded rules) from Learners to Model Aggregators.
+ *
+ * The event key is the rule ID.
+ *
+ * @author Anh Thu Vu
+ *
+ */
+public class PredicateContentEvent implements ContentEvent {
+
+  private static final long serialVersionUID = 7909435830443732451L;
+
+  private int ruleNumberID;
+  private RuleSplitNode ruleSplitNode;
+  private RulePassiveRegressionNode learningNode;
+
+  // No-arg constructor; presumably needed by the serialization framework — confirm.
+  public PredicateContentEvent() {
+    this(0, null, null);
+  }
+
+  public PredicateContentEvent(int ruleID, RuleSplitNode ruleSplitNode, RulePassiveRegressionNode learningNode) {
+    this.ruleNumberID = ruleID;
+    this.ruleSplitNode = ruleSplitNode; // if this is null, the event only updates the learning node
+    this.learningNode = learningNode;
+  }
+
+  @Override
+  public String getKey() {
+    return Integer.toString(this.ruleNumberID);
+  }
+
+  @Override
+  public void setKey(String key) {
+    // do nothing: the key is derived from the rule ID
+  }
+
+  @Override
+  public boolean isLastEvent() {
+    return false; // N/A
+  }
+
+  public int getRuleNumberID() {
+    return this.ruleNumberID;
+  }
+
+  public RuleSplitNode getRuleSplitNode() {
+    return this.ruleSplitNode;
+  }
+
+  public RulePassiveRegressionNode getLearningNode() {
+    return this.learningNode;
+  }
+
+}
http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/9b178f63/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/rules/distributed/RuleContentEvent.java
----------------------------------------------------------------------
diff --git a/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/rules/distributed/RuleContentEvent.java b/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/rules/distributed/RuleContentEvent.java
new file mode 100644
index 0000000..5209b52
--- /dev/null
+++ b/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/rules/distributed/RuleContentEvent.java
@@ -0,0 +1,80 @@
+package org.apache.samoa.learners.classifiers.rules.distributed;
+
+/*
+ * #%L
+ * SAMOA
+ * %%
+ * Copyright (C) 2014 - 2015 Apache Software Foundation
+ * %%
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * #L%
+ */
+
+import org.apache.samoa.core.ContentEvent;
+import org.apache.samoa.learners.classifiers.rules.common.ActiveRule;
+
+/**
+ * New rule from Model Aggregator/Default Rule Learner to Learners or removed rule from Learner to Model Aggregators.
+ *
+ * @author Anh Thu Vu
+ *
+ */
+public class RuleContentEvent implements ContentEvent {
+
+  /**
+   * Serialization version identifier.
+   */
+  private static final long serialVersionUID = -9046390274402894461L;
+
+  private final int ruleNumberID;
+  private final ActiveRule addingRule; // for removing rule, we only need the rule's ID
+  private final boolean isRemoving;
+
+  public RuleContentEvent() {
+    this(0, null, false);
+  }
+
+  public RuleContentEvent(int ruleID, ActiveRule rule, boolean isRemoving) {
+    this.ruleNumberID = ruleID;
+    this.isRemoving = isRemoving;
+    this.addingRule = rule;
+  }
+
+  @Override
+  public String getKey() {
+    return Integer.toString(this.ruleNumberID);
+  }
+
+  @Override
+  public void setKey(String key) {
+    // intentionally ignored: the key is always derived from ruleNumberID
+  }
+
+  @Override
+  public boolean isLastEvent() {
+    return false;
+  }
+
+  public int getRuleNumberID() {
+    return this.ruleNumberID;
+  }
+
+  public ActiveRule getRule() {
+    return this.addingRule;
+  }
+
+  public boolean isRemoving() {
+    return this.isRemoving;
+  }
+
+}
http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/9b178f63/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/trees/ActiveLearningNode.java
----------------------------------------------------------------------
diff --git a/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/trees/ActiveLearningNode.java b/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/trees/ActiveLearningNode.java
new file mode 100644
index 0000000..860b951
--- /dev/null
+++ b/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/trees/ActiveLearningNode.java
@@ -0,0 +1,215 @@
+package org.apache.samoa.learners.classifiers.trees;
+
+/*
+ * #%L
+ * SAMOA
+ * %%
+ * Copyright (C) 2014 - 2015 Apache Software Foundation
+ * %%
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * #L%
+ */
+
+import java.util.HashMap;
+import java.util.Map;
+
+import org.apache.samoa.instances.Instance;
+import org.apache.samoa.moa.classifiers.core.AttributeSplitSuggestion;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Learning node that batches per-attribute events for distributed processing and merges the
+ * split suggestions it receives; an attempt is complete after {@code parallelismHint} responses.
+ */
+final class ActiveLearningNode extends LearningNode {
+  /**
+   * Serialization version identifier.
+   */
+  private static final long serialVersionUID = -2892102872646338908L;
+  private static final Logger logger = LoggerFactory.getLogger(ActiveLearningNode.class);
+
+  private double weightSeenAtLastSplitEvaluation;
+
+  private final Map<Integer, String> attributeContentEventKeys;
+
+  private AttributeSplitSuggestion bestSuggestion;
+  private AttributeSplitSuggestion secondBestSuggestion;
+
+  private final long id;
+  private final int parallelismHint;
+  private int suggestionCtr;
+  private int thrownAwayInstance;
+
+  private boolean isSplitting;
+
+  ActiveLearningNode(double[] classObservation, int parallelismHint) {
+    super(classObservation);
+    this.weightSeenAtLastSplitEvaluation = this.getWeightSeen();
+    this.id = VerticalHoeffdingTree.LearningNodeIdGenerator.generate();
+    this.attributeContentEventKeys = new HashMap<>();
+    this.isSplitting = false;
+    this.parallelismHint = parallelismHint;
+  }
+
+  long getId() {
+    return id;
+  }
+
+  protected AttributeBatchContentEvent[] attributeBatchContentEvent;
+
+  public AttributeBatchContentEvent[] getAttributeBatchContentEvent() {
+    return this.attributeBatchContentEvent;
+  }
+
+  public void setAttributeBatchContentEvent(AttributeBatchContentEvent[] attributeBatchContentEvent) {
+    this.attributeBatchContentEvent = attributeBatchContentEvent;
+  }
+
+  @Override
+  void learnFromInstance(Instance inst, ModelAggregatorProcessor proc) {
+    // TODO: what statistics should we keep for unused instance?
+    if (isSplitting) { // instances that arrive while a split is being evaluated are dropped
+      this.thrownAwayInstance++;
+      return;
+    }
+    this.observedClassDistribution.addToValue((int) inst.classValue(),
+        inst.weight());
+    // One AttributeContentEvent per non-class attribute, appended to that attribute's batch;
+    // the batches are exposed through getAttributeBatchContentEvent().
+    for (int i = 0; i < inst.numAttributes() - 1; i++) {
+      int instAttIndex = modelAttIndexToInstanceAttIndex(i, inst);
+      Integer obsIndex = i;
+      String key = attributeContentEventKeys.get(obsIndex);
+
+      if (key == null) {
+        key = this.generateKey(i);
+        attributeContentEventKeys.put(obsIndex, key);
+      }
+      AttributeContentEvent ace = new AttributeContentEvent.Builder(
+          this.id, i, key)
+          .attrValue(inst.value(instAttIndex))
+          .classValue((int) inst.classValue())
+          .weight(inst.weight())
+          .isNominal(inst.attribute(instAttIndex).isNominal())
+          .build();
+      if (this.attributeBatchContentEvent == null) {
+        this.attributeBatchContentEvent = new AttributeBatchContentEvent[inst.numAttributes() - 1];
+      }
+      if (this.attributeBatchContentEvent[i] == null) {
+        this.attributeBatchContentEvent[i] = new AttributeBatchContentEvent.Builder(
+            this.id, i, key)
+            .isNominal(inst.attribute(instAttIndex).isNominal())
+            .build();
+      }
+      this.attributeBatchContentEvent[i].add(ace);
+    }
+  }
+
+  @Override
+  double[] getClassVotes(Instance inst, ModelAggregatorProcessor map) {
+    return this.observedClassDistribution.getArrayCopy();
+  }
+
+  double getWeightSeen() {
+    return this.observedClassDistribution.sumOfValues();
+  }
+
+  void setWeightSeenAtLastSplitEvaluation(double weight) {
+    this.weightSeenAtLastSplitEvaluation = weight;
+  }
+
+  double getWeightSeenAtLastSplitEvaluation() {
+    return this.weightSeenAtLastSplitEvaluation;
+  }
+
+  /**
+   * Starts a distributed split attempt: marks the node as splitting, clears the bookkeeping of
+   * any earlier attempt and sends a {@code ComputeContentEvent} on the control stream.
+   */
+  void requestDistributedSuggestions(long splitId, ModelAggregatorProcessor modelAggrProc) {
+    this.isSplitting = true;
+    this.suggestionCtr = 0;
+    this.thrownAwayInstance = 0;
+    // Suggestions kept from an earlier attempt were computed on older statistics; without this
+    // reset they would keep competing with (and could beat) the suggestions of this attempt.
+    this.bestSuggestion = null;
+    this.secondBestSuggestion = null;
+
+    ComputeContentEvent cce = new ComputeContentEvent(splitId, this.id,
+        this.getObservedClassDistribution());
+    modelAggrProc.sendToControlStream(cce);
+  }
+
+  /**
+   * Merges one (best, second-best) pair of suggestions into this node's running top two and
+   * counts the response towards {@link #isAllSuggestionsCollected()}. Either argument may be null.
+   */
+  void addDistributedSuggestions(AttributeSplitSuggestion bestSuggestion, AttributeSplitSuggestion secondBestSuggestion) {
+    // starts comparing from the best suggestion
+    if (bestSuggestion != null) {
+      if ((this.bestSuggestion == null) || (bestSuggestion.compareTo(this.bestSuggestion) > 0)) {
+        this.secondBestSuggestion = this.bestSuggestion;
+        this.bestSuggestion = bestSuggestion;
+
+        if (secondBestSuggestion != null) {
+
+          if ((this.secondBestSuggestion == null) || (secondBestSuggestion.compareTo(this.secondBestSuggestion) > 0)) {
+            this.secondBestSuggestion = secondBestSuggestion;
+          }
+        }
+      } else {
+        if ((this.secondBestSuggestion == null) || (bestSuggestion.compareTo(this.secondBestSuggestion) > 0)) {
+          this.secondBestSuggestion = bestSuggestion;
+        }
+      }
+    }
+
+    // TODO: optimize the code to use less memory
+    this.suggestionCtr++;
+  }
+
+  boolean isSplitting() {
+    return this.isSplitting;
+  }
+
+  void endSplitting() {
+    this.isSplitting = false;
+    logger.trace("wasted instance: {}", this.thrownAwayInstance);
+    this.thrownAwayInstance = 0;
+  }
+
+  AttributeSplitSuggestion getDistributedBestSuggestion() {
+    return this.bestSuggestion;
+  }
+
+  AttributeSplitSuggestion getDistributedSecondBestSuggestion() {
+    return this.secondBestSuggestion;
+  }
+
+  boolean isAllSuggestionsCollected() {
+    return (this.suggestionCtr == this.parallelismHint);
+  }
+
+  private static int modelAttIndexToInstanceAttIndex(int index, Instance inst) {
+    return inst.classIndex() > index ? index : index + 1;
+  }
+
+  private String generateKey(int obsIndex) {
+    final int prime = 31;
+    int result = 1;
+    result = prime * result + (int) (this.id ^ (this.id >>> 32));
+    result = prime * result + obsIndex;
+    return Integer.toString(result);
+  }
+}
http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/9b178f63/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/trees/AttributeBatchContentEvent.java
----------------------------------------------------------------------
diff --git a/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/trees/AttributeBatchContentEvent.java b/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/trees/AttributeBatchContentEvent.java
new file mode 100644
index 0000000..8973de1
--- /dev/null
+++ b/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/trees/AttributeBatchContentEvent.java
@@ -0,0 +1,130 @@
+package org.apache.samoa.learners.classifiers.trees;
+
+/*
+ * #%L
+ * SAMOA
+ * %%
+ * Copyright (C) 2014 - 2015 Apache Software Foundation
+ * %%
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * #L%
+ */
+
+import java.util.LinkedList;
+import java.util.List;
+
+import org.apache.samoa.core.ContentEvent;
+
+/**
+ * Batch of content events for one attribute ({@code obsIndex}) of one learning node ({@code learningNodeId}).
+ *
+ * @author Arinto Murdopo
+ *
+ */
+final class AttributeBatchContentEvent implements ContentEvent {
+
+  private static final long serialVersionUID = 6652815649846676832L;
+
+  private final long learningNodeId;
+  private final int obsIndex;
+  private final List<ContentEvent> contentEventList;
+  private final transient String key; // NOTE(review): transient, so not restored on deserialization; confirm receivers never call getKey()
+  private final boolean isNominal;
+
+  public AttributeBatchContentEvent() {
+    learningNodeId = -1;
+    obsIndex = -1;
+    contentEventList = new LinkedList<>();
+    key = "";
+    isNominal = true;
+  }
+
+  private AttributeBatchContentEvent(Builder builder) {
+    this.learningNodeId = builder.learningNodeId;
+    this.obsIndex = builder.obsIndex;
+    this.contentEventList = new LinkedList<>();
+    if (builder.contentEvent != null) {
+      this.contentEventList.add(builder.contentEvent);
+    }
+    this.isNominal = builder.isNominal;
+    this.key = builder.key;
+  }
+
+  public void add(ContentEvent contentEvent) {
+    this.contentEventList.add(contentEvent);
+  }
+
+  @Override
+  public String getKey() {
+    return this.key;
+  }
+
+  @Override
+  public void setKey(String str) {
+    // do nothing, maybe useful when we want to reuse the object for
+    // serialization/deserialization purpose
+  }
+
+  @Override
+  public boolean isLastEvent() {
+    return false;
+  }
+
+  long getLearningNodeId() {
+    return this.learningNodeId;
+  }
+
+  int getObsIndex() {
+    return this.obsIndex;
+  }
+
+  public List<ContentEvent> getContentEventList() {
+    return this.contentEventList;
+  }
+
+  boolean isNominal() {
+    return this.isNominal;
+  }
+
+  static final class Builder {
+
+    // required parameters
+    private final long learningNodeId;
+    private final int obsIndex;
+    private final String key;
+
+    private ContentEvent contentEvent;
+    private boolean isNominal = false;
+
+    Builder(long id, int obsIndex, String key) {
+      this.learningNodeId = id;
+      this.obsIndex = obsIndex;
+      this.key = key;
+    }
+
+    Builder contentEvent(ContentEvent contentEvent) {
+      this.contentEvent = contentEvent;
+      return this;
+    }
+
+    Builder isNominal(boolean val) {
+      this.isNominal = val;
+      return this;
+    }
+
+    AttributeBatchContentEvent build() {
+      return new AttributeBatchContentEvent(this);
+    }
+  }
+
+}
