Repository: incubator-samoa Updated Branches: refs/heads/master f9db1f271 -> 1bd1012af (forced update)
SAMOA-48: Fix for VHT Project: http://git-wip-us.apache.org/repos/asf/incubator-samoa/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-samoa/commit/af25e7d1 Tree: http://git-wip-us.apache.org/repos/asf/incubator-samoa/tree/af25e7d1 Diff: http://git-wip-us.apache.org/repos/asf/incubator-samoa/diff/af25e7d1 Branch: refs/heads/master Commit: af25e7d1f464323dcd0e2bb5729dc70e5ad36887 Parents: d454deb Author: Gianmarco De Francisci Morales <[email protected]> Authored: Mon Oct 26 14:20:55 2015 +0200 Committer: Gianmarco De Francisci Morales <[email protected]> Committed: Mon Oct 26 14:20:55 2015 +0200 ---------------------------------------------------------------------- .../trees/ModelAggregatorProcessor.java | 47 ++++++++------------ 1 file changed, 19 insertions(+), 28 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/af25e7d1/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/trees/ModelAggregatorProcessor.java ---------------------------------------------------------------------- diff --git a/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/trees/ModelAggregatorProcessor.java b/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/trees/ModelAggregatorProcessor.java index 1e79f48..846d8e1 100644 --- a/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/trees/ModelAggregatorProcessor.java +++ b/samoa-api/src/main/java/org/apache/samoa/learners/classifiers/trees/ModelAggregatorProcessor.java @@ -123,8 +123,7 @@ final class ModelAggregatorProcessor implements Processor { SplittingNodeInfo splittingNode = splittingNodes.get(timedOutSplitId); if (splittingNode != null) { this.splittingNodes.remove(timedOutSplitId); - this.continueAttemptToSplit(splittingNode.activeLearningNode, - splittingNode.foundNode); + this.continueAttemptToSplit(splittingNode.activeLearningNode, splittingNode.foundNode); } @@ -168,15 +167,12 @@ final class ModelAggregatorProcessor implements Processor { // removed by timeout thread ActiveLearningNode activeLearningNode = splittingNodeInfo.activeLearningNode; - activeLearningNode.addDistributedSuggestions( - lrce.getBestSuggestion(), - lrce.getSecondBestSuggestion()); + activeLearningNode.addDistributedSuggestions(lrce.getBestSuggestion(), lrce.getSecondBestSuggestion()); if (activeLearningNode.isAllSuggestionsCollected()) { splittingNodeInfo.scheduledFuture.cancel(false); this.splittingNodes.remove(lrceSplitId); - this.continueAttemptToSplit(activeLearningNode, - splittingNodeInfo.foundNode); + this.continueAttemptToSplit(activeLearningNode, splittingNodeInfo.foundNode); } } } @@ -205,8 +201,7 @@ final class ModelAggregatorProcessor implements Processor { @Override public Processor newProcessor(Processor p) { ModelAggregatorProcessor oldProcessor = (ModelAggregatorProcessor) p; - ModelAggregatorProcessor newProcessor = - new ModelAggregatorProcessor.Builder(oldProcessor).build(); + ModelAggregatorProcessor newProcessor = new ModelAggregatorProcessor.Builder(oldProcessor).build(); newProcessor.setResultStream(oldProcessor.resultStream); newProcessor.setAttributeStream(oldProcessor.attributeStream); @@ -264,8 +259,13 @@ final class ModelAggregatorProcessor implements Processor { } private ResultContentEvent newResultContentEvent(double[] prediction, Instance inst, InstancesContentEvent inEvent) { + boolean isLastEvent = false; + if (inEvent.isLastEvent()) { + Instance[] tmp = inEvent.getInstances(); + isLastEvent = inst == tmp[tmp.length - 1]; // only set LastEvent on the last instance in the mini-batch + } ResultContentEvent rce = new ResultContentEvent(inEvent.getInstanceIndex(), inst, (int) inst.classValue(), - prediction, inEvent.isLastEvent()); + prediction, isLastEvent); rce.setClassifierIndex(this.processorId); rce.setEvaluationIndex(inEvent.getEvaluationIndex()); return rce; @@ -314,8 +314,7 @@ final class ModelAggregatorProcessor implements Processor { double[] prediction = null; if (isTesting) { prediction = getVotesForInstance(inst, false); - this.resultStream.put(newResultContentEvent(prediction, inst, - instContentEvent)); + this.resultStream.put(newResultContentEvent(prediction, inst, instContentEvent)); } if (isTraining) { @@ -363,15 +362,13 @@ final class ModelAggregatorProcessor implements Processor { return foundList.toArray(new FoundNode[foundList.size()]); } - protected void findNodes(Node node, SplitNode parent, - int parentBranch, List<FoundNode> found) { + protected void findNodes(Node node, SplitNode parent, int parentBranch, List<FoundNode> found) { if (node != null) { found.add(new FoundNode(node, parent, parentBranch)); if (node instanceof SplitNode) { SplitNode splitNode = (SplitNode) node; for (int i = 0; i < splitNode.numChildren(); i++) { - findNodes(splitNode.getChild(i), splitNode, i, - found); + findNodes(splitNode.getChild(i), splitNode, i, found); } } } @@ -466,8 +463,7 @@ final class ModelAggregatorProcessor implements Processor { // Schedule time-out thread ScheduledFuture<?> timeOutHandler = this.executor.schedule(new AggregationTimeOutHandler(this.splitId, - this.timedOutSplittingNodes), - this.timeOut, TimeUnit.SECONDS); + this.timedOutSplittingNodes), this.timeOut, TimeUnit.SECONDS); // Keep track of the splitting node information, so that we can continue the // split @@ -494,10 +490,8 @@ final class ModelAggregatorProcessor implements Processor { // compare with null split double[] preSplitDist = activeLearningNode.getObservedClassDistribution(); - AttributeSplitSuggestion nullSplit = new AttributeSplitSuggestion(null, - new double[0][], this.splitCriterion.getMeritOfSplit( - preSplitDist, - new double[][] { preSplitDist })); + AttributeSplitSuggestion nullSplit = new AttributeSplitSuggestion(null, new double[0][], + this.splitCriterion.getMeritOfSplit(preSplitDist, new double[][] { preSplitDist })); if ((bestSuggestion == null) || (nullSplit.compareTo(bestSuggestion) > 0)) { secondBestSuggestion = bestSuggestion; @@ -514,12 +508,10 @@ final class ModelAggregatorProcessor implements Processor { shouldSplit = (bestSuggestion != null); } else { double hoeffdingBound = computeHoeffdingBound( - this.splitCriterion.getRangeOfMerit(activeLearningNode.getObservedClassDistribution()), - this.splitConfidence, + this.splitCriterion.getRangeOfMerit(activeLearningNode.getObservedClassDistribution()), this.splitConfidence, activeLearningNode.getWeightSeen()); - if ((bestSuggestion.merit - secondBestSuggestion.merit > hoeffdingBound) - || (hoeffdingBound < tieThreshold)) { + if ((bestSuggestion.merit - secondBestSuggestion.merit > hoeffdingBound) || (hoeffdingBound < tieThreshold)) { shouldSplit = true; } // TODO: add poor attributes removal @@ -597,8 +589,7 @@ final class ModelAggregatorProcessor implements Processor { private void setModelContext(InstancesHeader ih) { // TODO possibly refactored if ((ih != null) && (ih.classIndex() < 0)) { - throw new IllegalArgumentException( - "Context for a classifier must include a class to learn"); + throw new IllegalArgumentException("Context for a classifier must include a class to learn"); } // TODO: check flag for checking whether training has started or not
