http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/787864b6/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/trees/ModelAggregatorProcessor.java ---------------------------------------------------------------------- diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/trees/ModelAggregatorProcessor.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/trees/ModelAggregatorProcessor.java new file mode 100644 index 0000000..cf7a1b3 --- /dev/null +++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/trees/ModelAggregatorProcessor.java @@ -0,0 +1,731 @@ +package com.yahoo.labs.samoa.learners.classifiers.trees; + +/* + * #%L + * SAMOA + * %% + * Copyright (C) 2013 Yahoo! Inc. + * %% + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * #L% + */ + +import java.util.HashSet; +import java.util.LinkedList; +import java.util.List; +import java.util.Set; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.Executors; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.ScheduledFuture; +import java.util.concurrent.TimeUnit; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import com.yahoo.labs.samoa.core.ContentEvent; +import com.yahoo.labs.samoa.core.Processor; +import com.yahoo.labs.samoa.instances.Instance; +import com.yahoo.labs.samoa.instances.Instances; +import com.yahoo.labs.samoa.instances.InstancesHeader; +import com.yahoo.labs.samoa.learners.InstanceContentEvent; +import com.yahoo.labs.samoa.learners.InstancesContentEvent; +import com.yahoo.labs.samoa.learners.ResultContentEvent; +import com.yahoo.labs.samoa.moa.classifiers.core.AttributeSplitSuggestion; +import com.yahoo.labs.samoa.moa.classifiers.core.driftdetection.ChangeDetector; +import com.yahoo.labs.samoa.moa.classifiers.core.splitcriteria.InfoGainSplitCriterion; +import com.yahoo.labs.samoa.moa.classifiers.core.splitcriteria.SplitCriterion; +import com.yahoo.labs.samoa.topology.Stream; + +import static com.yahoo.labs.samoa.moa.core.Utils.maxIndex; + +/** + * Model Aggegator Processor consists of the decision tree model. It connects + * to local-statistic PI via attribute stream and control stream. + * Model-aggregator PI sends the split instances via attribute stream and + * it sends control messages to ask local-statistic PI to perform computation + * via control stream. + * + * Model-aggregator PI sends the classification result via result stream to + * an evaluator PI for classifier or other destination PI. The calculation + * results from local statistic arrive to the model-aggregator PI via + * computation-result stream. 
+ + * @author Arinto Murdopo + * + */ +final class ModelAggregatorProcessor implements Processor { + + private static final long serialVersionUID = -1685875718300564886L; + private static final Logger logger = LoggerFactory.getLogger(ModelAggregatorProcessor.class); + + private int processorId; + + private Node treeRoot; + + private int activeLeafNodeCount; + private int inactiveLeafNodeCount; + private int decisionNodeCount; + private boolean growthAllowed; + + private final Instances dataset; + + //to support concurrent split + private long splitId; + private ConcurrentMap<Long, SplittingNodeInfo> splittingNodes; + private BlockingQueue<Long> timedOutSplittingNodes; + + //available streams + private Stream resultStream; + private Stream attributeStream; + private Stream controlStream; + + private transient ScheduledExecutorService executor; + + private final SplitCriterion splitCriterion; + private final double splitConfidence; + private final double tieThreshold; + private final int gracePeriod; + private final int parallelismHint; + private final long timeOut; + + //private constructor based on Builder pattern + private ModelAggregatorProcessor(Builder builder){ + this.dataset = builder.dataset; + this.splitCriterion = builder.splitCriterion; + this.splitConfidence = builder.splitConfidence; + this.tieThreshold = builder.tieThreshold; + this.gracePeriod = builder.gracePeriod; + this.parallelismHint = builder.parallelismHint; + this.timeOut = builder.timeOut; + this.changeDetector = builder.changeDetector; + + InstancesHeader ih = new InstancesHeader(dataset); + this.setModelContext(ih); + } + + @Override + public boolean process(ContentEvent event) { + + //Poll the blocking queue shared between ModelAggregator and the time-out threads + Long timedOutSplitId = timedOutSplittingNodes.poll(); + if(timedOutSplitId != null){ //time out has been reached! 
+ SplittingNodeInfo splittingNode = splittingNodes.get(timedOutSplitId); + if (splittingNode != null) { + this.splittingNodes.remove(timedOutSplitId); + this.continueAttemptToSplit(splittingNode.activeLearningNode, + splittingNode.foundNode); + + } + + } + + //Receive a new instance from source + if(event instanceof InstancesContentEvent){ + InstancesContentEvent instancesEvent = (InstancesContentEvent) event; + this.processInstanceContentEvent(instancesEvent); + //Send information to local-statistic PI + //for each of the nodes + if (this.foundNodeSet != null){ + for (FoundNode foundNode: this.foundNodeSet ){ + ActiveLearningNode leafNode = (ActiveLearningNode) foundNode.getNode(); + AttributeBatchContentEvent[] abce = leafNode.getAttributeBatchContentEvent(); + if (abce != null) { + for (int i = 0; i< this.dataset.numAttributes() - 1; i++) { + this.sendToAttributeStream(abce[i]); + } + } + leafNode.setAttributeBatchContentEvent(null); + //this.sendToControlStream(event); //split information + //See if we can ask for splits + if(!leafNode.isSplitting()){ + double weightSeen = leafNode.getWeightSeen(); + //check whether it is the time for splitting + if(weightSeen - leafNode.getWeightSeenAtLastSplitEvaluation() >= this.gracePeriod){ + attemptToSplit(leafNode, foundNode); + } + } + } + } + this.foundNodeSet = null; + } else if(event instanceof LocalResultContentEvent){ + LocalResultContentEvent lrce = (LocalResultContentEvent) event; + Long lrceSplitId = lrce.getSplitId(); + SplittingNodeInfo splittingNodeInfo = splittingNodes.get(lrceSplitId); + + if (splittingNodeInfo != null) { // if null, that means + // activeLearningNode has been + // removed by timeout thread + ActiveLearningNode activeLearningNode = splittingNodeInfo.activeLearningNode; + + activeLearningNode.addDistributedSuggestions( + lrce.getBestSuggestion(), + lrce.getSecondBestSuggestion()); + + if (activeLearningNode.isAllSuggestionsCollected()) { + splittingNodeInfo.scheduledFuture.cancel(false); + 
this.splittingNodes.remove(lrceSplitId); + this.continueAttemptToSplit(activeLearningNode, + splittingNodeInfo.foundNode); + } + } + } + return false; + } + + protected Set<FoundNode> foundNodeSet; + + @Override + public void onCreate(int id) { + this.processorId = id; + + this.activeLeafNodeCount = 0; + this.inactiveLeafNodeCount = 0; + this.decisionNodeCount = 0; + this.growthAllowed = true; + + this.splittingNodes = new ConcurrentHashMap<>(); + this.timedOutSplittingNodes = new LinkedBlockingQueue<>(); + this.splitId = 0; + + //Executor for scheduling time-out threads + this.executor = Executors.newScheduledThreadPool(8); + } + + @Override + public Processor newProcessor(Processor p) { + ModelAggregatorProcessor oldProcessor = (ModelAggregatorProcessor)p; + ModelAggregatorProcessor newProcessor = + new ModelAggregatorProcessor.Builder(oldProcessor).build(); + + newProcessor.setResultStream(oldProcessor.resultStream); + newProcessor.setAttributeStream(oldProcessor.attributeStream); + newProcessor.setControlStream(oldProcessor.controlStream); + return newProcessor; + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + sb.append(super.toString()); + + sb.append("ActiveLeafNodeCount: ").append(activeLeafNodeCount); + sb.append("InactiveLeafNodeCount: ").append(inactiveLeafNodeCount); + sb.append("DecisionNodeCount: ").append(decisionNodeCount); + sb.append("Growth allowed: ").append(growthAllowed); + return sb.toString(); + } + + void setResultStream(Stream resultStream){ + this.resultStream = resultStream; + } + + void setAttributeStream(Stream attributeStream){ + this.attributeStream = attributeStream; + } + + void setControlStream(Stream controlStream){ + this.controlStream = controlStream; + } + + void sendToAttributeStream(ContentEvent event){ + this.attributeStream.put(event); + } + + void sendToControlStream(ContentEvent event){ + this.controlStream.put(event); + } + + /** + * Helper method to generate new 
ResultContentEvent based on an instance and + * its prediction result. + * @param prediction The predicted class label from the decision tree model. + * @param inEvent The associated instance content event + * @return ResultContentEvent to be sent into Evaluator PI or other destination PI. + */ + private ResultContentEvent newResultContentEvent(double[] prediction, InstanceContentEvent inEvent){ + ResultContentEvent rce = new ResultContentEvent(inEvent.getInstanceIndex(), inEvent.getInstance(), + inEvent.getClassId(), prediction, inEvent.isLastEvent()); + rce.setClassifierIndex(this.processorId); + rce.setEvaluationIndex(inEvent.getEvaluationIndex()); + return rce; + } + + private ResultContentEvent newResultContentEvent(double[] prediction, Instance inst, InstancesContentEvent inEvent){ + ResultContentEvent rce = new ResultContentEvent(inEvent.getInstanceIndex(), inst, (int) inst.classValue(), prediction, inEvent.isLastEvent()); + rce.setClassifierIndex(this.processorId); + rce.setEvaluationIndex(inEvent.getEvaluationIndex()); + return rce; + } + + private List<InstancesContentEvent> contentEventList = new LinkedList<>(); + + + /** + * Helper method to process the InstanceContentEvent + * @param instContentEvent + */ + private void processInstanceContentEvent(InstancesContentEvent instContentEvent){ + this.numBatches++; + this.contentEventList.add(instContentEvent); + if (this.numBatches == 1 || this.numBatches > 4){ + this.processInstances(this.contentEventList.remove(0)); + } + + if (instContentEvent.isLastEvent()) { + // drain remaining instances + while (!contentEventList.isEmpty()) { + processInstances(contentEventList.remove(0)); + } + } + + } + + private int numBatches = 0; + + private void processInstances(InstancesContentEvent instContentEvent){ + + Instance[] instances = instContentEvent.getInstances(); + boolean isTesting = instContentEvent.isTesting(); + boolean isTraining= instContentEvent.isTraining(); + for (Instance inst: instances){ + 
this.processInstance(inst,instContentEvent, isTesting, isTraining); + } + } + + private void processInstance(Instance inst, InstancesContentEvent instContentEvent, boolean isTesting, boolean isTraining){ + inst.setDataset(this.dataset); + //Check the instance whether it is used for testing or training + //boolean testAndTrain = isTraining; //Train after testing + double[] prediction = null; + if (isTesting) { + prediction = getVotesForInstance(inst, false); + this.resultStream.put(newResultContentEvent(prediction, inst, + instContentEvent)); + } + + if (isTraining) { + trainOnInstanceImpl(inst); + if (this.changeDetector != null) { + if (prediction == null) { + prediction = getVotesForInstance(inst); + } + boolean correctlyClassifies = this.correctlyClassifies(inst,prediction); + double oldEstimation = this.changeDetector.getEstimation(); + this.changeDetector.input(correctlyClassifies ? 0 : 1); + if (this.changeDetector.getEstimation() > oldEstimation) { + //Start a new classifier + logger.info("Change detected, resetting the classifier"); + this.resetLearning(); + this.changeDetector.resetLearning(); + } + } + } + } + + private boolean correctlyClassifies(Instance inst, double[] prediction) { + return maxIndex(prediction) == (int) inst.classValue(); + } + + private void resetLearning() { + this.treeRoot = null; + //Remove nodes + FoundNode[] learningNodes = findNodes(); + for (FoundNode learningNode : learningNodes) { + Node node = learningNode.getNode(); + if (node instanceof SplitNode) { + SplitNode splitNode; + splitNode = (SplitNode) node; + for (int i = 0; i < splitNode.numChildren(); i++) { + splitNode.setChild(i, null); + } + } + } + } + + protected FoundNode[] findNodes() { + List<FoundNode> foundList = new LinkedList<>(); + findNodes(this.treeRoot, null, -1, foundList); + return foundList.toArray(new FoundNode[foundList.size()]); + } + + protected void findNodes(Node node, SplitNode parent, + int parentBranch, List<FoundNode> found) { + if (node != null) 
{ + found.add(new FoundNode(node, parent, parentBranch)); + if (node instanceof SplitNode) { + SplitNode splitNode = (SplitNode) node; + for (int i = 0; i < splitNode.numChildren(); i++) { + findNodes(splitNode.getChild(i), splitNode, i, + found); + } + } + } + } + + + /** + * Helper method to get the prediction result. + * The actual prediction result is delegated to the leaf node. + * @param inst + * @return + */ + private double[] getVotesForInstance(Instance inst){ + return getVotesForInstance(inst, false); + } + + private double[] getVotesForInstance(Instance inst, boolean isTraining){ + double[] ret; + FoundNode foundNode = null; + if(this.treeRoot != null){ + foundNode = this.treeRoot.filterInstanceToLeaf(inst, null, -1); + Node leafNode = foundNode.getNode(); + if(leafNode == null){ + leafNode = foundNode.getParent(); + } + + ret = leafNode.getClassVotes(inst, this); + } else { + int numClasses = this.dataset.numClasses(); + ret = new double[numClasses]; + + } + + //Training after testing to speed up the process + if (isTraining){ + if(this.treeRoot == null){ + this.treeRoot = newLearningNode(this.parallelismHint); + this.activeLeafNodeCount = 1; + foundNode = this.treeRoot.filterInstanceToLeaf(inst, null, -1); + } + trainOnInstanceImpl(foundNode, inst); + } + return ret; + } + + /** + * Helper method that represent training of an instance. Since it is decision tree, + * this method routes the incoming instance into the correct leaf and then update the + * statistic on the found leaf. 
+ * @param inst + */ + private void trainOnInstanceImpl(Instance inst) { + if(this.treeRoot == null){ + this.treeRoot = newLearningNode(this.parallelismHint); + this.activeLeafNodeCount = 1; + + } + FoundNode foundNode = this.treeRoot.filterInstanceToLeaf(inst, null, -1); + trainOnInstanceImpl(foundNode, inst); + } + + private void trainOnInstanceImpl(FoundNode foundNode, Instance inst) { + + Node leafNode = foundNode.getNode(); + + if(leafNode == null){ + leafNode = newLearningNode(this.parallelismHint); + foundNode.getParent().setChild(foundNode.getParentBranch(), leafNode); + activeLeafNodeCount++; + } + + if(leafNode instanceof LearningNode){ + LearningNode learningNode = (LearningNode) leafNode; + learningNode.learnFromInstance(inst, this); + } + if (this.foundNodeSet == null){ + this.foundNodeSet = new HashSet<>(); + } + this.foundNodeSet.add(foundNode); + } + + /** + * Helper method to represent a split attempt + * @param activeLearningNode The corresponding active learning node which will be split + * @param foundNode The data structure to represents the filtering of the instance using the + * tree model. 
+ */ + private void attemptToSplit(ActiveLearningNode activeLearningNode, FoundNode foundNode){ + //Increment the split ID + this.splitId++; + + //Schedule time-out thread + ScheduledFuture<?> timeOutHandler = this.executor.schedule(new AggregationTimeOutHandler(this.splitId, this.timedOutSplittingNodes), + this.timeOut, TimeUnit.SECONDS); + + //Keep track of the splitting node information, so that we can continue the split + //once we receive all local statistic calculation from Local Statistic PI + //this.splittingNodes.put(Long.valueOf(this.splitId), new SplittingNodeInfo(activeLearningNode, foundNode, null)); + this.splittingNodes.put(this.splitId, new SplittingNodeInfo(activeLearningNode, foundNode, timeOutHandler)); + + //Inform Local Statistic PI to perform local statistic calculation + activeLearningNode.requestDistributedSuggestions(this.splitId, this); + } + + + /** + * Helper method to continue the attempt to split once all local calculation results are received. + * @param activeLearningNode The corresponding active learning node which will be split + * @param foundNode The data structure to represents the filtering of the instance using the + * tree model. 
+ */ + private void continueAttemptToSplit(ActiveLearningNode activeLearningNode, FoundNode foundNode){ + AttributeSplitSuggestion bestSuggestion = activeLearningNode.getDistributedBestSuggestion(); + AttributeSplitSuggestion secondBestSuggestion = activeLearningNode.getDistributedSecondBestSuggestion(); + + //compare with null split + double[] preSplitDist = activeLearningNode.getObservedClassDistribution(); + AttributeSplitSuggestion nullSplit = new AttributeSplitSuggestion(null, + new double[0][], this.splitCriterion.getMeritOfSplit( + preSplitDist, + new double[][]{preSplitDist})); + + if((bestSuggestion == null) || (nullSplit.compareTo(bestSuggestion) > 0)){ + secondBestSuggestion = bestSuggestion; + bestSuggestion = nullSplit; + }else{ + if((secondBestSuggestion == null) || (nullSplit.compareTo(secondBestSuggestion) > 0)){ + secondBestSuggestion = nullSplit; + } + } + + boolean shouldSplit = false; + + if(secondBestSuggestion == null){ + shouldSplit = (bestSuggestion != null); + }else{ + double hoeffdingBound = computeHoeffdingBound( + this.splitCriterion.getRangeOfMerit(activeLearningNode.getObservedClassDistribution()), + this.splitConfidence, + activeLearningNode.getWeightSeen()); + + if((bestSuggestion.merit - secondBestSuggestion.merit > hoeffdingBound) + || (hoeffdingBound < tieThreshold)) { + shouldSplit = true; + } + //TODO: add poor attributes removal + } + + SplitNode parent = foundNode.getParent(); + int parentBranch = foundNode.getParentBranch(); + + //split if the Hoeffding bound condition is satisfied + if(shouldSplit){ + + if (bestSuggestion.splitTest != null) { + SplitNode newSplit = new SplitNode(bestSuggestion.splitTest, activeLearningNode.getObservedClassDistribution()); + + for(int i = 0; i < bestSuggestion.numSplits(); i++){ + Node newChild = newLearningNode(bestSuggestion.resultingClassDistributionFromSplit(i), this.parallelismHint); + newSplit.setChild(i, newChild); + } + + this.activeLeafNodeCount--; + this.decisionNodeCount++; + 
this.activeLeafNodeCount += bestSuggestion.numSplits(); + + if(parent == null){ + this.treeRoot = newSplit; + }else{ + parent.setChild(parentBranch, newSplit); + } + } + //TODO: add check on the model's memory size + } + + //housekeeping + activeLearningNode.endSplitting(); + activeLearningNode.setWeightSeenAtLastSplitEvaluation(activeLearningNode.getWeightSeen()); + } + + /** + * Helper method to deactivate learning node + * @param toDeactivate Active Learning Node that will be deactivated + * @param parent Parent of the soon-to-be-deactivated Active LearningNode + * @param parentBranch the branch index of the node in the parent node + */ + private void deactivateLearningNode(ActiveLearningNode toDeactivate, SplitNode parent, int parentBranch){ + Node newLeaf = new InactiveLearningNode(toDeactivate.getObservedClassDistribution()); + if(parent == null){ + this.treeRoot = newLeaf; + }else{ + parent.setChild(parentBranch, newLeaf); + } + + this.activeLeafNodeCount--; + this.inactiveLeafNodeCount++; + } + + + private LearningNode newLearningNode(int parallelismHint){ + return newLearningNode(new double[0], parallelismHint); + } + + private LearningNode newLearningNode(double[] initialClassObservations, int parallelismHint){ + //for VHT optimization, we need to dynamically instantiate the appropriate ActiveLearningNode + return new ActiveLearningNode(initialClassObservations, parallelismHint); + } + + /** + * Helper method to set the model context, i.e. 
how many attributes they are and what is the class index + * @param ih + */ + private void setModelContext(InstancesHeader ih){ + //TODO possibly refactored + if ((ih != null) && (ih.classIndex() < 0)) { + throw new IllegalArgumentException( + "Context for a classifier must include a class to learn"); + } + //TODO: check flag for checking whether training has started or not + + //model context is used to describe the model + logger.trace("Model context: {}", ih.toString()); + } + + private static double computeHoeffdingBound(double range, double confidence, double n){ + return Math.sqrt((Math.pow(range, 2.0) * Math.log(1.0/confidence)) / (2.0*n)); + } + + /** + * AggregationTimeOutHandler is a class to support time-out feature while waiting for local computation results + * from the local statistic PIs. + * @author Arinto Murdopo + * + */ + static class AggregationTimeOutHandler implements Runnable{ + + private static final Logger logger = LoggerFactory.getLogger(AggregationTimeOutHandler.class); + private final Long splitId; + private final BlockingQueue<Long> toBeSplittedNodes; + + AggregationTimeOutHandler(Long splitId, BlockingQueue<Long> toBeSplittedNodes){ + this.splitId = splitId; + this.toBeSplittedNodes = toBeSplittedNodes; + } + + @Override + public void run() { + logger.debug("Time out is reached. 
AggregationTimeOutHandler is started."); + try { + toBeSplittedNodes.put(splitId); + } catch (InterruptedException e) { + logger.warn("Interrupted while trying to put the ID into the queue"); + } + logger.debug("AggregationTimeOutHandler is finished."); + } + } + + /** + * SplittingNodeInfo is a class to represents the ActiveLearningNode that is splitting + * @author Arinto Murdopo + * + */ + static class SplittingNodeInfo{ + + private final ActiveLearningNode activeLearningNode; + private final FoundNode foundNode; + private final ScheduledFuture<?> scheduledFuture; + + SplittingNodeInfo(ActiveLearningNode activeLearningNode, FoundNode foundNode, ScheduledFuture<?> scheduledFuture){ + this.activeLearningNode = activeLearningNode; + this.foundNode = foundNode; + this.scheduledFuture = scheduledFuture; + } + } + + protected ChangeDetector changeDetector; + + public ChangeDetector getChangeDetector() { + return this.changeDetector; + } + + public void setChangeDetector(ChangeDetector cd) { + this.changeDetector = cd; + } + + /** + * Builder class to replace constructors with many parameters + * @author Arinto Murdopo + * + */ + static class Builder{ + + //required parameters + private final Instances dataset; + + //default values + private SplitCriterion splitCriterion = new InfoGainSplitCriterion(); + private double splitConfidence = 0.0000001; + private double tieThreshold = 0.05; + private int gracePeriod = 200; + private int parallelismHint = 1; + private long timeOut = 30; + private ChangeDetector changeDetector = null; + + Builder(Instances dataset){ + this.dataset = dataset; + } + + Builder(ModelAggregatorProcessor oldProcessor){ + this.dataset = oldProcessor.dataset; + this.splitCriterion = oldProcessor.splitCriterion; + this.splitConfidence = oldProcessor.splitConfidence; + this.tieThreshold = oldProcessor.tieThreshold; + this.gracePeriod = oldProcessor.gracePeriod; + this.parallelismHint = oldProcessor.parallelismHint; + this.timeOut = oldProcessor.timeOut; 
+ } + + Builder splitCriterion(SplitCriterion splitCriterion){ + this.splitCriterion = splitCriterion; + return this; + } + + Builder splitConfidence(double splitConfidence){ + this.splitConfidence = splitConfidence; + return this; + } + + Builder tieThreshold(double tieThreshold){ + this.tieThreshold = tieThreshold; + return this; + } + + Builder gracePeriod(int gracePeriod){ + this.gracePeriod = gracePeriod; + return this; + } + + Builder parallelismHint(int parallelismHint){ + this.parallelismHint = parallelismHint; + return this; + } + + Builder timeOut(long timeOut){ + this.timeOut = timeOut; + return this; + } + + Builder changeDetector(ChangeDetector changeDetector){ + this.changeDetector = changeDetector; + return this; + } + ModelAggregatorProcessor build(){ + return new ModelAggregatorProcessor(this); + } + } + +}
// Source: http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/787864b6/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/trees/Node.java
package com.yahoo.labs.samoa.learners.classifiers.trees;

/*
 * #%L
 * SAMOA
 * %%
 * Copyright (C) 2013 Yahoo! Inc.
 * %%
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * #L%
 */

import com.yahoo.labs.samoa.core.DoubleVector;
import com.yahoo.labs.samoa.instances.Instance;

/**
 * Abstract class that represents a node in the tree model. Every node keeps the class
 * distribution it has observed.
 *
 * @author Arinto Murdopo
 */
abstract class Node implements java.io.Serializable {

    private static final long serialVersionUID = 4008521239214180548L;

    /** Class distribution observed at this node. */
    protected final DoubleVector observedClassDistribution;

    /**
     * Method to route/filter an instance into its corresponding leaf. This method will be
     * invoked recursively.
     *
     * @param inst Instance to be routed
     * @param parent Parent of the current node
     * @param parentBranch The index of the current node in the parent
     * @return FoundNode which is the data structure to represent the resulting leaf.
     */
    abstract FoundNode filterInstanceToLeaf(Instance inst, SplitNode parent, int parentBranch);

    /**
     * Method to return the predicted class of the instance based on the statistic
     * inside the node.
     *
     * @param inst To-be-predicted instance
     * @param map the model aggregator that owns the tree
     * @return The prediction result in the form of class distribution
     */
    abstract double[] getClassVotes(Instance inst, ModelAggregatorProcessor map);

    /**
     * Method to check whether the node is a leaf node or not.
     *
     * @return Boolean flag to indicate whether the node is a leaf or not
     */
    abstract boolean isLeaf();

    /**
     * Constructor of the tree node
     *
     * @param classObservation distribution of the observed classes.
     */
    protected Node(double[] classObservation) {
        this.observedClassDistribution = new DoubleVector(classObservation);
    }

    /**
     * Getter method for the class distribution
     *
     * @return a copy of the observed class distribution
     */
    protected double[] getObservedClassDistribution() {
        return this.observedClassDistribution.getArrayCopy();
    }

    /**
     * A method to check whether the class distribution only consists of one class or not.
     *
     * @return true when fewer than two classes have a non-zero count
     */
    protected boolean observedClassDistributionIsPure() {
        return (observedClassDistribution.numNonZeroEntries() < 2);
    }

    protected void describeSubtree(ModelAggregatorProcessor modelAggrProc, StringBuilder out, int indent) {
        // TODO: implement method to gracefully define the tree
    }

    // TODO: calculate promise for limiting the model based on the memory size
    // double calculatePromise();
}

// Source: http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/787864b6/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/trees/SplitNode.java
package com.yahoo.labs.samoa.learners.classifiers.trees;

/*
 * #%L
 * SAMOA
 * %%
 * Copyright (C) 2013 Yahoo! Inc.
 * %%
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * #L%
 */

import com.yahoo.labs.samoa.moa.classifiers.core.conditionaltests.InstanceConditionalTest;
import com.yahoo.labs.samoa.moa.core.AutoExpandVector;
import com.yahoo.labs.samoa.instances.Instance;

/**
 * SplitNode represents the node that contains one or more questions in the decision tree model,
 * in order to route the instances into the correct leaf.
 *
 * @author Arinto Murdopo
 */
public class SplitNode extends Node {

    private static final long serialVersionUID = -7380795529928485792L;

    private final AutoExpandVector<Node> children;
    protected final InstanceConditionalTest splitTest;

    public SplitNode(InstanceConditionalTest splitTest, double[] classObservation) {
        super(classObservation);
        this.children = new AutoExpandVector<>();
        this.splitTest = splitTest;
    }

    /**
     * Routes the instance through the split test. If the target branch is empty, the result has
     * a null node with this split node as parent. If the test gives a negative branch index, this
     * split node itself is returned.
     */
    @Override
    FoundNode filterInstanceToLeaf(Instance inst, SplitNode parent, int parentBranch) {
        int childIndex = instanceChildIndex(inst);
        if (childIndex >= 0) {
            Node child = getChild(childIndex);
            if (child != null) {
                return child.filterInstanceToLeaf(inst, this, childIndex);
            }
            return new FoundNode(null, this, childIndex);
        }
        return new FoundNode(this, parent, parentBranch);
    }

    @Override
    boolean isLeaf() {
        return false;
    }

    /** Votes with a copy of the class distribution observed when this node was created. */
    @Override
    double[] getClassVotes(Instance inst, ModelAggregatorProcessor vht) {
        return this.observedClassDistribution.getArrayCopy();
    }

    /**
     * Method to return the number of children of this split node
     *
     * @return number of children
     */
    int numChildren() {
        return this.children.size();
    }

    /**
     * Method to set the children in a specific index of the SplitNode with the appropriate child
     *
     * @param index Index of the child in the SplitNode
     * @param child The child node
     * @throws IndexOutOfBoundsException if {@code index} is negative, or not below the split
     *             test's maximum branch count (when that count is non-negative)
     */
    void setChild(int index, Node child) {
        int maxBranches = this.splitTest.maxBranches();
        if (index < 0 || ((maxBranches >= 0) && (index >= maxBranches))) {
            throw new IndexOutOfBoundsException("Invalid child index " + index + " (maxBranches = "
                    + maxBranches + ")");
        }
        this.children.set(index, child);
    }

    /**
     * Method to get the child node given the index
     *
     * @param index The child node index
     * @return The child node in the given index
     */
    Node getChild(int index) {
        return this.children.get(index);
    }

    /**
     * Method to route the instance using this split node
     *
     * @param inst The routed instance
     * @return The index of the branch where the instance is routed
     */
    int instanceChildIndex(Instance inst) {
        return this.splitTest.branchForInstance(inst);
    }
}
http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/787864b6/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/trees/VerticalHoeffdingTree.java ---------------------------------------------------------------------- diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/trees/VerticalHoeffdingTree.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/trees/VerticalHoeffdingTree.java new file mode 100644 index 0000000..e8ccce7 --- /dev/null +++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/classifiers/trees/VerticalHoeffdingTree.java @@ -0,0 +1,185 @@ +package com.yahoo.labs.samoa.learners.classifiers.trees; + +/* + * #%L + * SAMOA + * %% + * Copyright (C) 2013 Yahoo! Inc. + * %% + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ * #L% + */

import com.google.common.collect.ImmutableSet;
import java.util.Set;

import com.github.javacliparser.ClassOption;
import com.github.javacliparser.Configurable;
import com.github.javacliparser.FlagOption;
import com.github.javacliparser.FloatOption;
import com.github.javacliparser.IntOption;
import com.yahoo.labs.samoa.core.Processor;
import com.yahoo.labs.samoa.instances.Instances;
import com.yahoo.labs.samoa.learners.AdaptiveLearner;
import com.yahoo.labs.samoa.learners.ClassificationLearner;
import com.yahoo.labs.samoa.moa.classifiers.core.attributeclassobservers.AttributeClassObserver;
import com.yahoo.labs.samoa.moa.classifiers.core.attributeclassobservers.DiscreteAttributeClassObserver;
import com.yahoo.labs.samoa.moa.classifiers.core.attributeclassobservers.NumericAttributeClassObserver;
import com.yahoo.labs.samoa.moa.classifiers.core.driftdetection.ChangeDetector;
import com.yahoo.labs.samoa.moa.classifiers.core.splitcriteria.SplitCriterion;
import com.yahoo.labs.samoa.topology.Stream;
import com.yahoo.labs.samoa.topology.TopologyBuilder;

/**
 * Vertical Hoeffding Tree (VHT).
 *
 * <p>A distributed classifier that applies vertical parallelism on top of the
 * Very Fast Decision Tree (VFDT) classifier. {@link #init} wires a filter
 * processor into a model aggregator, which talks to a set of local statistics
 * processors over an attribute stream and a control stream and receives their
 * results over a computation-result stream.
 *
 * @author Arinto Murdopo
 */
public final class VerticalHoeffdingTree implements ClassificationLearner, AdaptiveLearner, Configurable {

    private static final long serialVersionUID = -4937416312929984057L;

    public ClassOption numericEstimatorOption = new ClassOption("numericEstimator",
            'n', "Numeric estimator to use.", NumericAttributeClassObserver.class,
            "GaussianNumericAttributeClassObserver");

    public ClassOption nominalEstimatorOption = new ClassOption("nominalEstimator",
            'd', "Nominal estimator to use.", DiscreteAttributeClassObserver.class,
            "NominalAttributeClassObserver");

    public ClassOption splitCriterionOption = new ClassOption("splitCriterion",
            's', "Split criterion to use.", SplitCriterion.class,
            "InfoGainSplitCriterion");

    public FloatOption splitConfidenceOption = new FloatOption("splitConfidence", 'c',
            "The allowable error in split decision, values closer to 0 will take longer to decide.",
            0.0000001, 0.0, 1.0);

    public FloatOption tieThresholdOption = new FloatOption("tieThreshold", 't',
            "Threshold below which a split will be forced to break ties.",
            0.05, 0.0, 1.0);

    public IntOption gracePeriodOption = new IntOption("gracePeriod", 'g',
            "The number of instances a leaf should observe between split attempts.",
            200, 0, Integer.MAX_VALUE);

    public IntOption parallelismHintOption = new IntOption("parallelismHint", 'p',
            "The number of local statistics PI to do distributed computation",
            1, 1, Integer.MAX_VALUE);

    public IntOption timeOutOption = new IntOption("timeOut", 'o',
            "The duration to wait all distributed computation results from local statistics PI",
            30, 1, Integer.MAX_VALUE);

    public FlagOption binarySplitsOption = new FlagOption("binarySplits", 'b',
            "Only allow binary splits.");

    private Stream resultStream;

    private FilterProcessor filterProc;

    /** Drift detector handed to the model aggregator; set via {@link #setChangeDetector}. */
    protected ChangeDetector changeDetector;

    /**
     * Builds the VHT topology.
     *
     * @param topologyBuilder builder that receives the processors and streams
     * @param dataset         header of the training data
     * @param parallelism     parallelism of the filter and model-aggregator processors;
     *                        the local statistics processors use {@link #parallelismHintOption}
     */
    @Override
    public void init(TopologyBuilder topologyBuilder, Instances dataset, int parallelism) {
        // Filter stage: the learner's input processor.
        this.filterProc = new FilterProcessor.Builder(dataset).build();
        topologyBuilder.addProcessor(filterProc, parallelism);
        final Stream filtered = topologyBuilder.createStream(filterProc);
        this.filterProc.setOutputStream(filtered);

        // Model aggregator, fed from the filter through a shuffle stream.
        final ModelAggregatorProcessor aggregator = buildModelAggregator(dataset);
        topologyBuilder.addProcessor(aggregator, parallelism);
        topologyBuilder.connectInputShuffleStream(filtered, aggregator);

        // Output streams of the aggregator, created in a fixed order: result, attribute, control.
        this.resultStream = topologyBuilder.createStream(aggregator);
        aggregator.setResultStream(this.resultStream);

        final Stream attributes = topologyBuilder.createStream(aggregator);
        aggregator.setAttributeStream(attributes);

        final Stream control = topologyBuilder.createStream(aggregator);
        aggregator.setControlStream(control);

        // Local statistics processors: attribute stream as key stream, control stream as all stream.
        final LocalStatisticsProcessor localStats = buildLocalStatistics();
        topologyBuilder.addProcessor(localStats, parallelismHintOption.getValue());
        topologyBuilder.connectInputKeyStream(attributes, localStats);
        topologyBuilder.connectInputAllStream(control, localStats);

        // Computation results travel back to the aggregator.
        final Stream computed = topologyBuilder.createStream(localStats);
        localStats.setComputationResultStream(computed);
        topologyBuilder.connectInputAllStream(computed, aggregator);
    }

    /** Creates the model aggregator configured from the current option values. */
    private ModelAggregatorProcessor buildModelAggregator(Instances dataset) {
        return new ModelAggregatorProcessor.Builder(dataset)
                .splitCriterion((SplitCriterion) this.splitCriterionOption.getValue())
                .splitConfidence(splitConfidenceOption.getValue())
                .tieThreshold(tieThresholdOption.getValue())
                .gracePeriod(gracePeriodOption.getValue())
                .parallelismHint(parallelismHintOption.getValue())
                .timeOut(timeOutOption.getValue())
                .changeDetector(this.getChangeDetector())
                .build();
    }

    /** Creates the local statistics processor configured from the current option values. */
    private LocalStatisticsProcessor buildLocalStatistics() {
        return new LocalStatisticsProcessor.Builder()
                .splitCriterion((SplitCriterion) this.splitCriterionOption.getValue())
                .binarySplit(binarySplitsOption.isSet())
                .nominalClassObserver((AttributeClassObserver) this.nominalEstimatorOption.getValue())
                .numericClassObserver((AttributeClassObserver) this.numericEstimatorOption.getValue())
                .build();
    }

    /** @return the filter processor created by {@link #init} */
    @Override
    public Processor getInputProcessor() {
        return this.filterProc;
    }

    /** @return a singleton set holding the model aggregator's result stream */
    @Override
    public Set<Stream> getResultStreams() {
        return ImmutableSet.of(this.resultStream);
    }

    @Override
    public ChangeDetector getChangeDetector() {
        return this.changeDetector;
    }

    @Override
    public void setChangeDetector(ChangeDetector cd) {
        this.changeDetector = cd;
    }

    /** Hands out sequential learning-node ids from one static counter. */
    static class LearningNodeIdGenerator {

        // TODO: warn the user when the value reaches Long.MAX_VALUE
        private static long id = 0;

        /** @return the current counter value (0 on the first call), then advances it */
        static synchronized long generate() {
            final long current = id;
            id = current + 1;
            return current;
        }
    }
}
http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/787864b6/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/clusterers/ClusteringContentEvent.java ---------------------------------------------------------------------- diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/clusterers/ClusteringContentEvent.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/clusterers/ClusteringContentEvent.java new file mode 100644 index 0000000..a0d950b --- /dev/null +++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/clusterers/ClusteringContentEvent.java @@ -0,0 +1,89 @@ +package com.yahoo.labs.samoa.learners.clusterers; + +/* + * #%L + * SAMOA + * %% + * Copyright (C) 2013 Yahoo! Inc. + * %% + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ * #L% + */

import com.yahoo.labs.samoa.core.ContentEvent;
import com.yahoo.labs.samoa.instances.Instance;

/**
 * Content event that carries one instance through a clustering topology.
 *
 * <p>This class is <em>not</em> immutable: its key, last-event flag and sample
 * flag can all be changed through setters after construction, and no
 * synchronization is performed.
 */
public final class ClusteringContentEvent implements ContentEvent {

    private static final long serialVersionUID = -7746983521296618922L;
    private Instance instance;
    private boolean isLast = false;
    private String key;
    private boolean isSample;

    /** No-arg constructor required by the Kryo serializer. */
    public ClusteringContentEvent() {
        // Necessary for kryo serializer
    }

    /**
     * Creates an event for {@code instance}.
     *
     * @param index    position of the instance; its decimal string form becomes the event key
     * @param instance the instance to carry (stored by reference, not copied)
     */
    public ClusteringContentEvent(long index, Instance instance) {
        this.instance = instance;
        this.setKey(Long.toString(index));
    }

    @Override
    public String getKey() {
        return this.key;
    }

    @Override
    public void setKey(String str) {
        this.key = str;
    }

    @Override
    public boolean isLastEvent() {
        return this.isLast;
    }

    /** @param isLast whether this is the last event of the stream */
    public void setLast(boolean isLast) {
        this.isLast = isLast;
    }

    /** @return the carried instance */
    public Instance getInstance() {
        return this.instance;
    }

    /** @return whether this event is marked as a sample */
    public boolean isSample() {
        return isSample;
    }

    /** @param b whether this event is marked as a sample */
    public void setSample(boolean b) {
        this.isSample = b;
    }
}
http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/787864b6/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/clusterers/ClustreamClustererAdapter.java ---------------------------------------------------------------------- diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/clusterers/ClustreamClustererAdapter.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/clusterers/ClustreamClustererAdapter.java new file mode 100644 index 0000000..057e37b --- /dev/null +++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/clusterers/ClustreamClustererAdapter.java @@ -0,0 +1,166 @@ +package
com.yahoo.labs.samoa.learners.clusterers; + +/* + * #%L + * SAMOA + * %% + * Copyright (C) 2013 Yahoo! Inc. + * %% + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * #L% + */ +/** + * License + */

import com.github.javacliparser.ClassOption;
import com.github.javacliparser.Configurable;
import com.yahoo.labs.samoa.instances.Instance;
import com.yahoo.labs.samoa.instances.Instances;
import com.yahoo.labs.samoa.instances.InstancesHeader;
import com.yahoo.labs.samoa.moa.cluster.Clustering;
import com.yahoo.labs.samoa.moa.clusterers.clustream.Clustream;

/**
 * {@link LocalClustererAdapter} that wraps a MOA-style clusterer: {@link Clustream}
 * unless another class is selected through {@link #learnerOption}.
 */
public class ClustreamClustererAdapter implements LocalClustererAdapter, Configurable {

    private static final long serialVersionUID = 4372366401338704353L;

    /** Option selecting the clusterer class to wrap (defaults to Clustream). */
    public ClassOption learnerOption = new ClassOption("learner", 'l',
            "Clusterer to train.", com.yahoo.labs.samoa.moa.clusterers.Clusterer.class, Clustream.class.getName());

    /** The wrapped clusterer; always a copy of the one supplied or configured. */
    protected com.yahoo.labs.samoa.moa.clusterers.Clusterer learner;

    /** Whether the wrapped clusterer has received its model context (done on the first training call). */
    protected Boolean isInit;

    /** Header describing the attributes of the instances being clustered. */
    protected Instances dataset;

    /**
     * Creates an adapter around a copy of {@code learner}.
     *
     * @param learner the clusterer to copy and wrap
     * @param dataset header of the instances to cluster (not validated here; can be
     *                replaced later with {@link #setDataset(Instances)})
     */
    public ClustreamClustererAdapter(com.yahoo.labs.samoa.moa.clusterers.Clusterer learner, Instances dataset) {
        this.learner = learner.copy();
        this.isInit = false;
        this.dataset = dataset;
    }

    /**
     * Creates an adapter around a copy of the clusterer selected by {@link #learnerOption}.
     * No dataset header is set; supply one with {@link #setDataset(Instances)}.
     */
    public ClustreamClustererAdapter() {
        this.learner = ((com.yahoo.labs.samoa.moa.clusterers.Clusterer) this.learnerOption.getValue()).copy();
        this.isInit = false;
    }

    @Override
    public void setDataset(Instances dataset) {
        this.dataset = dataset;
    }

    @Override
    public Instances getDataset() {
        return this.dataset;
    }

    /**
     * Creates a new adapter wrapping a copy of this adapter's clusterer and sharing
     * its dataset header. A message is printed to standard output when the header
     * is {@code null}.
     *
     * @return the new adapter
     */
    @Override
    public ClustreamClustererAdapter create() {
        ClustreamClustererAdapter l = new ClustreamClustererAdapter(learner, dataset);
        if (dataset == null) {
            System.out.println("dataset null while creating");
        }
        return l;
    }

    /**
     * Trains the wrapped clusterer incrementally. On the first call the clusterer
     * receives an {@link InstancesHeader} built from the dataset and is prepared
     * for use. Instances whose weight is not positive are not trained on.
     *
     * @param inst the instance to train on; its dataset is set to this adapter's header
     */
    @Override
    public void trainOnInstance(Instance inst) {
        if (!this.isInit) {
            this.isInit = true;
            InstancesHeader instances = new InstancesHeader(dataset);
            this.learner.setModelContext(instances);
            this.learner.prepareForUse();
        }
        if (inst.weight() > 0) {
            inst.setDataset(dataset);
            learner.trainOnInstance(inst);
        }
    }

    /**
     * Returns the wrapped clusterer's votes for {@code inst}. Before the first
     * training call the result is an all-zero array of length
     * {@code dataset.numClasses()}.
     *
     * @param inst the instance to score; its dataset is set to this adapter's header
     * @return the votes for {@code inst}
     */
    @Override
    public double[] getVotesForInstance(Instance inst) {
        double[] ret;
        inst.setDataset(dataset);
        if (!this.isInit) {
            ret = new double[dataset.numClasses()];
        } else {
            ret = learner.getVotesForInstance(inst);
        }
        return ret;
    }

    /**
     * Delegates to the wrapped clusterer's {@code resetLearning()}. The
     * {@link #isInit} flag is left unchanged.
     */
    @Override
    public void resetLearning() {
        learner.resetLearning();
    }

    @Override
    public boolean implementsMicroClusterer() {
        return this.learner.implementsMicroClusterer();
    }

    @Override
    public Clustering getMicroClusteringResult() {
        return this.learner.getMicroClusteringResult();
    }
}
http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/787864b6/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/clusterers/LocalClustererAdapter.java ---------------------------------------------------------------------- diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/clusterers/LocalClustererAdapter.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/clusterers/LocalClustererAdapter.java new file mode 100644 index 0000000..fedbcfe --- /dev/null +++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/clusterers/LocalClustererAdapter.java @@ -0,0 +1,82 @@ +package com.yahoo.labs.samoa.learners.clusterers; + +/* + * #%L + * SAMOA + * %% + * Copyright (C) 2013 Yahoo! Inc. + * %% + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * #L% + */ + +import com.yahoo.labs.samoa.instances.Instance; +import com.yahoo.labs.samoa.instances.Instances; +import com.yahoo.labs.samoa.moa.cluster.Clustering; +import java.io.Serializable; + +/** + * Learner interface for non-distributed learners.
*
 * <p>Implementations wrap a single-node (local) clusterer so that it can be
 * driven by a {@link LocalClustererProcessor}.
 *
 * @author abifet
 */
public interface LocalClustererAdapter extends Serializable {

    /**
     * Creates a new adapter object derived from this one.
     *
     * @return the new adapter
     */
    LocalClustererAdapter create();

    /**
     * Returns the membership votes for {@code inst}. If the instance cannot be
     * scored, every element of the returned array must be zero.
     *
     * @param inst the instance to score
     * @return the estimated membership values of {@code inst}
     */
    double[] getVotesForInstance(Instance inst);

    /**
     * Resets the wrapped clusterer; afterwards it should behave like a freshly
     * created one.
     */
    void resetLearning();

    /**
     * Trains the wrapped clusterer incrementally on one instance.
     *
     * @param inst the instance to train on
     */
    void trainOnInstance(Instance inst);

    /**
     * Sets the header that describes the attributes of the instances.
     *
     * @param dataset the dataset header
     */
    void setDataset(Instances dataset);

    /** @return the dataset header currently in use */
    Instances getDataset();

    /** @return whether the wrapped clusterer produces a micro-clustering */
    boolean implementsMicroClusterer();

    /** @return the wrapped clusterer's current micro-clustering */
    Clustering getMicroClusteringResult();
}
http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/787864b6/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/clusterers/LocalClustererProcessor.java ---------------------------------------------------------------------- diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/clusterers/LocalClustererProcessor.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/clusterers/LocalClustererProcessor.java new file mode 100644 index 0000000..a397539 --- /dev/null +++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/clusterers/LocalClustererProcessor.java @@ -0,0 +1,191 @@ +package com.yahoo.labs.samoa.learners.clusterers; + +/* + * #%L + * SAMOA + * %% + * Copyright (C) 2013 Yahoo! Inc.
+ * %% + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * #L% + */ +/** + * License + */

import com.yahoo.labs.samoa.evaluation.ClusteringEvaluationContentEvent;
import com.yahoo.labs.samoa.evaluation.ClusteringResultContentEvent;
import com.yahoo.labs.samoa.core.ContentEvent;
import com.yahoo.labs.samoa.core.Processor;
import com.yahoo.labs.samoa.instances.DenseInstance;
import com.yahoo.labs.samoa.instances.Instance;
import com.yahoo.labs.samoa.moa.cluster.Clustering;
import com.yahoo.labs.samoa.moa.core.DataPoint;
import com.yahoo.labs.samoa.topology.Stream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Processor that trains a {@link LocalClustererAdapter} on incoming events and
 * periodically emits the adapter's micro-clustering on its output stream.
 *
 * <p>Two kinds of training events are handled: a {@link ClusteringContentEvent}
 * contributes its single instance (local clustering), and a
 * {@link ClusteringResultContentEvent} contributes the centre of every cluster it
 * carries (global clustering).
 */
public final class LocalClustererProcessor implements Processor {

    private static final long serialVersionUID = -1577910988699148691L;
    private static final Logger logger = LoggerFactory
            .getLogger(LocalClustererProcessor.class);
    private LocalClustererAdapter model;
    private Stream outputStream;
    private int modelId;
    private long instancesCount = 0;
    private long sampleFrequency = 1000;

    /** @return number of trained points between two emitted clusterings */
    public long getSampleFrequency() {
        return sampleFrequency;
    }

    /**
     * @param sampleFrequency number of trained points between two emitted clusterings
     * @throws IllegalArgumentException if {@code sampleFrequency} is not positive
     *         (it is used as a divisor)
     */
    public void setSampleFrequency(long sampleFrequency) {
        if (sampleFrequency <= 0) {
            throw new IllegalArgumentException("sampleFrequency must be positive: " + sampleFrequency);
        }
        this.sampleFrequency = sampleFrequency;
    }

    /** @param model the adapter to train */
    public void setLearner(LocalClustererAdapter model) {
        this.model = model;
    }

    /** @return the adapter being trained */
    public LocalClustererAdapter getLearner() {
        return model;
    }

    /** @param outputStream stream on which clustering results are emitted */
    public void setOutputStream(Stream outputStream) {
        this.outputStream = outputStream;
    }

    /** @return stream on which clustering results are emitted */
    public Stream getOutputStream() {
        return outputStream;
    }

    /** @return number of points the model has been trained on so far */
    public long getInstancesCount() {
        return instancesCount;
    }

    /**
     * Trains the model on the point(s) carried by {@code event} and logs progress
     * each time the training count passes a multiple of {@code sampleFrequency}.
     * Events of other types do not train the model.
     */
    private void updateStats(ContentEvent event) {
        final long countBefore = instancesCount;

        if (event instanceof ClusteringContentEvent) {
            // Local clustering: one instance per event.
            ClusteringContentEvent ev = (ClusteringContentEvent) event;
            trainOn(ev.getInstance(), event.getKey());
        }

        if (event instanceof ClusteringResultContentEvent) {
            // Global clustering: train on every received cluster centre.
            ClusteringResultContentEvent ev = (ClusteringResultContentEvent) event;
            Clustering clustering = ev.getClustering();
            for (int i = 0; i < clustering.size(); i++) {
                Instance center = new DenseInstance(1.0, clustering.get(i).getCenter());
                center.setDataset(model.getDataset());
                trainOn(center, event.getKey());
            }
        }

        // Log only when this event actually advanced the count past a multiple of
        // sampleFrequency; a plain modulo test would log "0 events" for non-training
        // events and could skip multiples when one event adds several points.
        if (instancesCount / sampleFrequency > countBefore / sampleFrequency) {
            logger.info("Trained model using {} events with classifier id {}", instancesCount, this.modelId);
        }
    }

    /** Wraps {@code instance} in a DataPoint timestamped with the parsed key and trains on it. */
    private void trainOn(Instance instance, String key) {
        DataPoint point = new DataPoint(instance, Integer.parseInt(key));
        model.trainOnInstance(point);
        instancesCount++;
    }

    /**
     * Handles one event. If the event is the last one, or the training count is a
     * positive multiple of {@code sampleFrequency}, the current micro-clustering
     * is emitted first (only when the adapter implements a micro-clusterer). The
     * model is then trained on the event.
     *
     * @param event the incoming event
     * @return always {@code false}
     */
    @Override
    public boolean process(ContentEvent event) {
        if (event.isLastEvent()
                || (instancesCount > 0 && instancesCount % this.sampleFrequency == 0)) {
            if (model.implementsMicroClusterer()) {
                Clustering clustering = model.getMicroClusteringResult();
                ClusteringResultContentEvent resultEvent =
                        new ClusteringResultContentEvent(clustering, event.isLastEvent());
                this.outputStream.put(resultEvent);
            }
        }

        updateStats(event);
        return false;
    }

    /**
     * Records the processor id and replaces the model with {@code model.create()}.
     */
    @Override
    public void onCreate(int id) {
        this.modelId = id;
        model = model.create();
    }

    /**
     * Creates a replica of {@code sourceProcessor}. The replica gets a model created
     * from the source's model (if any), the same output stream, and the same sample
     * frequency.
     */
    @Override
    public Processor newProcessor(Processor sourceProcessor) {
        LocalClustererProcessor newProcessor = new LocalClustererProcessor();
        LocalClustererProcessor originProcessor = (LocalClustererProcessor) sourceProcessor;
        if (originProcessor.getLearner() != null) {
            newProcessor.setLearner(originProcessor.getLearner().create());
        }
        newProcessor.setOutputStream(originProcessor.getOutputStream());
        // Previously dropped, so replicas silently fell back to the default of 1000.
        newProcessor.setSampleFrequency(originProcessor.getSampleFrequency());
        return newProcessor;
    }
}
http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/787864b6/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/clusterers/SingleLearner.java ---------------------------------------------------------------------- diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/clusterers/SingleLearner.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/clusterers/SingleLearner.java new file mode 100644 index 0000000..894a0cc --- /dev/null +++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/clusterers/SingleLearner.java @@ -0,0 +1,97 @@ +package com.yahoo.labs.samoa.learners.clusterers; + +/* + * #%L + * SAMOA + * %% + * Copyright (C) 2013 Yahoo! Inc.
+ * %% + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * #L% + */ + + +/** + * License + */

import com.google.common.collect.ImmutableSet;
import java.util.Set;

import com.github.javacliparser.ClassOption;
import com.github.javacliparser.Configurable;
import com.yahoo.labs.samoa.core.Processor;
import com.yahoo.labs.samoa.instances.Instances;
import com.yahoo.labs.samoa.learners.Learner;
import com.yahoo.labs.samoa.topology.Stream;
import com.yahoo.labs.samoa.topology.TopologyBuilder;

/**
 * Learner whose topology is a single {@link LocalClustererProcessor}; that
 * processor is both the input processor and the producer of the result stream.
 */
public final class SingleLearner implements Learner, Configurable {

    private static final long serialVersionUID = 684111382631697031L;

    /** Option selecting the local clusterer adapter (defaults to ClustreamClustererAdapter). */
    public ClassOption learnerOption = new ClassOption("learner", 'l',
            "Learner to train.", LocalClustererAdapter.class, ClustreamClustererAdapter.class.getName());

    // State captured by init() and used by setLayout().
    private TopologyBuilder builder;
    private Instances dataset;
    private int parallelism;

    // Topology elements created by setLayout().
    private LocalClustererProcessor learnerP;
    private Stream resultStream;

    /**
     * Remembers the builder, dataset and parallelism, then builds the topology.
     */
    @Override
    public void init(TopologyBuilder builder, Instances dataset, int parallelism) {
        this.builder = builder;
        this.dataset = dataset;
        this.parallelism = parallelism;
        this.setLayout();
    }

    /**
     * Adds one clusterer processor, configured with the selected adapter, at the
     * requested parallelism and gives it a fresh output stream.
     */
    protected void setLayout() {
        final LocalClustererProcessor processor = new LocalClustererProcessor();
        learnerP = processor;

        final LocalClustererAdapter adapter = (LocalClustererAdapter) this.learnerOption.getValue();
        adapter.setDataset(this.dataset);
        processor.setLearner(adapter);

        this.builder.addProcessor(processor, this.parallelism);
        final Stream output = this.builder.createStream(processor);
        resultStream = output;
        processor.setOutputStream(output);
    }

    /** @return the clusterer processor, which receives the learner's input */
    @Override
    public Processor getInputProcessor() {
        return learnerP;
    }

    /** @return a singleton set holding the clusterer's output stream */
    @Override
    public Set<Stream> getResultStreams() {
        return ImmutableSet.of(this.resultStream);
    }
}
http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/787864b6/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/clusterers/simple/ClusteringDistributorProcessor.java ---------------------------------------------------------------------- diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/clusterers/simple/ClusteringDistributorProcessor.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/clusterers/simple/ClusteringDistributorProcessor.java new file mode 100644 index 0000000..e75a1bd
--- /dev/null +++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/clusterers/simple/ClusteringDistributorProcessor.java @@ -0,0 +1,98 @@ +package com.yahoo.labs.samoa.learners.clusterers.simple; + +/* + * #%L + * SAMOA + * %% + * Copyright (C) 2013 Yahoo! Inc. + * %% + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * #L% + */ +/** + * License + */

import com.yahoo.labs.samoa.core.ContentEvent;
import com.yahoo.labs.samoa.core.Processor;
import com.yahoo.labs.samoa.evaluation.ClusteringEvaluationContentEvent;
import com.yahoo.labs.samoa.learners.clusterers.ClusteringContentEvent;
import com.yahoo.labs.samoa.moa.core.DataPoint;
import com.yahoo.labs.samoa.topology.Stream;

/**
 * Processor that distributes clustering events: instances go to the output
 * stream, and sampled instances and evaluation events go to the evaluation stream.
 */
public class ClusteringDistributorProcessor implements Processor {

    private static final long serialVersionUID = -1550901409625192730L;

    private Stream outputStream;
    private Stream evaluationStream;
    /** Counter used as the timestamp of sampled data points. */
    private int numInstances;

    public Stream getOutputStream() {
        return outputStream;
    }

    public void setOutputStream(Stream outputStream) {
        this.outputStream = outputStream;
    }

    public Stream getEvaluationStream() {
        return evaluationStream;
    }

    public void setEvaluationStream(Stream evaluationStream) {
        this.evaluationStream = evaluationStream;
    }

    /**
     * Routes one event.
     * <ul>
     * <li>A {@link ClusteringContentEvent} is forwarded to the output stream. If it is
     * marked as a sample, its instance is also wrapped in a
     * {@link ClusteringEvaluationContentEvent} and sent to the evaluation stream.</li>
     * <li>A {@link ClusteringEvaluationContentEvent} is forwarded to the evaluation stream.</li>
     * <li>Any other event is ignored.</li>
     * </ul>
     *
     * @param event the incoming event
     * @return always {@code true}
     */
    @Override
    public boolean process(ContentEvent event) {
        if (event instanceof ClusteringContentEvent) {
            ClusteringContentEvent cce = (ClusteringContentEvent) event;
            outputStream.put(event);
            if (cce.isSample()) {
                evaluationStream.put(new ClusteringEvaluationContentEvent(null,
                        new DataPoint(cce.getInstance(), numInstances++), cce.isLastEvent()));
            }
        } else if (event instanceof ClusteringEvaluationContentEvent) {
            evaluationStream.put(event);
        }
        return true;
    }

    /**
     * Creates a replica that shares the source's output and evaluation streams
     * (each copied only when non-null); the sample counter starts at zero.
     */
    @Override
    public Processor newProcessor(Processor sourceProcessor) {
        ClusteringDistributorProcessor newProcessor = new ClusteringDistributorProcessor();
        ClusteringDistributorProcessor originProcessor = (ClusteringDistributorProcessor) sourceProcessor;
        if (originProcessor.getOutputStream() != null) {
            newProcessor.setOutputStream(originProcessor.getOutputStream());
        }
        if (originProcessor.getEvaluationStream() != null) {
            newProcessor.setEvaluationStream(originProcessor.getEvaluationStream());
        }
        return newProcessor;
    }

    @Override
    public void onCreate(int id) {
        // Nothing to initialise.
    }
}
http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/787864b6/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/clusterers/simple/DistributedClusterer.java ---------------------------------------------------------------------- diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/clusterers/simple/DistributedClusterer.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/clusterers/simple/DistributedClusterer.java new file mode 100644 index 0000000..d924733 --- /dev/null +++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/clusterers/simple/DistributedClusterer.java @@ -0,0 +1,118 @@ +package com.yahoo.labs.samoa.learners.clusterers.simple; + +/* +
#%L + * SAMOA + * %% + * Copyright (C) 2013 Yahoo! Inc. + * %% + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * #L% + */ + +/** + * License + */

import com.google.common.collect.ImmutableSet;
import java.util.Set;

import com.github.javacliparser.ClassOption;
import com.github.javacliparser.Configurable;
import com.github.javacliparser.IntOption;
import com.yahoo.labs.samoa.core.Processor;
import com.yahoo.labs.samoa.instances.Instances;
import com.yahoo.labs.samoa.learners.Learner;
import com.yahoo.labs.samoa.learners.clusterers.*;
import com.yahoo.labs.samoa.topology.ProcessingItem;
import com.yahoo.labs.samoa.topology.Stream;
import com.yahoo.labs.samoa.topology.TopologyBuilder;

/**
 * Two-level clusterer. {@link #paralellismOption} local
 * {@link LocalClustererProcessor}s cluster the input and send their clusterings
 * to a single global {@link LocalClustererProcessor}. The global processor trains
 * on the received cluster centres and emits its own clustering on the result stream.
 */
public final class DistributedClusterer implements Learner, Configurable {

    private static final long serialVersionUID = 684111382631697031L;

    private Stream resultStream;

    private Instances dataset;

    public ClassOption learnerOption = new ClassOption("learner", 'l', "Clusterer to use.", LocalClustererAdapter.class,
            ClustreamClustererAdapter.class.getName());

    public IntOption paralellismOption = new IntOption("paralellismOption", 'P', "The paralellism level for concurrent processes", 2, 1, Integer.MAX_VALUE);

    private TopologyBuilder builder;

    private LocalClustererProcessor learnerP;

    private Stream localToGlobalStream;

    /**
     * Builds the topology.
     *
     * @param builder     builder that receives the processors and streams
     * @param dataset     header of the data to cluster
     * @param parallelism ignored; the number of local clusterers comes from
     *                    {@link #paralellismOption}
     */
    @Override
    public void init(TopologyBuilder builder, Instances dataset, int parallelism) {
        this.builder = builder;
        this.dataset = dataset;
        this.setLayout();
    }

    /** Creates the local clusterers, the global combiner, and the streams between them. */
    protected void setLayout() {
        // Local clustering
        learnerP = new LocalClustererProcessor();
        LocalClustererAdapter learner = (LocalClustererAdapter) this.learnerOption.getValue();
        learner.setDataset(this.dataset);
        learnerP.setLearner(learner);
        builder.addProcessor(learnerP, this.paralellismOption.getValue());
        localToGlobalStream = this.builder.createStream(learnerP);
        learnerP.setOutputStream(localToGlobalStream);

        // Global clustering: the combiner must use the adapter configured for it
        // (previously it was given the local `learner` and `globalLearner` went unused).
        // NOTE(review): learnerOption.getValue() may return the same instance both times;
        // each LocalClustererProcessor calls create() on its model in onCreate.
        LocalClustererProcessor globalClusteringCombinerP = new LocalClustererProcessor();
        LocalClustererAdapter globalLearner = (LocalClustererAdapter) this.learnerOption.getValue();
        globalLearner.setDataset(this.dataset);
        globalClusteringCombinerP.setLearner(globalLearner);
        builder.addProcessor(globalClusteringCombinerP, 1);
        builder.connectInputAllStream(localToGlobalStream, globalClusteringCombinerP);

        // Output stream
        resultStream = this.builder.createStream(globalClusteringCombinerP);
        globalClusteringCombinerP.setOutputStream(resultStream);
    }

    /** @return the local clusterer processor, which receives the learner's input */
    @Override
    public Processor getInputProcessor() {
        return learnerP;
    }

    /** @return a singleton set holding the global combiner's output stream */
    @Override
    public Set<Stream> getResultStreams() {
        return ImmutableSet.of(this.resultStream);
    }
}
http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/787864b6/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/AbstractMOAObject.java ---------------------------------------------------------------------- diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/AbstractMOAObject.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/AbstractMOAObject.java new file mode 100644 index 0000000..37303ec --- /dev/null +++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/AbstractMOAObject.java @@ -0,0 +1,80 @@ +package com.yahoo.labs.samoa.moa; + +/* + * #%L + * SAMOA + * %% + * Copyright (C) 2007 University of Waikato, Hamilton, New Zealand + * %% + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * #L% + */ + +import com.yahoo.labs.samoa.moa.core.SerializeUtils; +//import moa.core.SizeOf; + +/** + * Abstract MOA Object.
All classes that are serializable, copiable, + * can measure its size, and can give a description, extend this class. + * + * @author Richard Kirkby ([email protected]) + * @version $Revision: 7 $ + */ +public abstract class AbstractMOAObject implements MOAObject { + + @Override + public MOAObject copy() { + return copy(this); + } + + @Override + public int measureByteSize() { + return measureByteSize(this); + } + + /** + * Returns a description of the object. + * + * @return a description of the object + */ + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + getDescription(sb, 0); + return sb.toString(); + } + + /** + * This method produces a copy of an object. + * + * @param obj object to copy + * @return a copy of the object + */ + public static MOAObject copy(MOAObject obj) { + try { + return (MOAObject) SerializeUtils.copyObject(obj); + } catch (Exception e) { + throw new RuntimeException("Object copy failed.", e); + } + } + + /** + * Gets the memory size of an object. + * + * @param obj object to measure the memory size + * @return the memory size of this object + */ + public static int measureByteSize(MOAObject obj) { + return 0; //(int) SizeOf.fullSizeOf(obj); + } +} http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/787864b6/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/MOAObject.java ---------------------------------------------------------------------- diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/MOAObject.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/MOAObject.java new file mode 100644 index 0000000..cc26eaa --- /dev/null +++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/MOAObject.java @@ -0,0 +1,58 @@ +package com.yahoo.labs.samoa.moa; + +/* + * #%L + * SAMOA + * %% + * Copyright (C) 2007 University of Waikato, Hamilton, New Zealand + * %% + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * #L% + */ + +import java.io.Serializable; + +/** + * Interface implemented by classes in MOA, so that all are serializable, + * can produce copies of their objects, and can measure its memory size. + * They also give a string description. + * + * @author Richard Kirkby ([email protected]) + * @version $Revision: 7 $ + */ +public interface MOAObject extends Serializable { + + /** + * Gets the memory size of this object. + * + * @return the memory size of this object + */ + public int measureByteSize(); + + /** + * This method produces a copy of this object. + * + * @return a copy of this object + */ + public MOAObject copy(); + + /** + * Returns a string representation of this object. + * Used in <code>AbstractMOAObject.toString</code> + * to give a string representation of the object. + * + * @param sb the stringbuilder to add the description + * @param indent the number of characters to indent + */ + public void getDescription(StringBuilder sb, int indent); +}
