http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/23a35dbe/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/splitcriteria/InfoGainSplitCriterion.java ---------------------------------------------------------------------- diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/splitcriteria/InfoGainSplitCriterion.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/splitcriteria/InfoGainSplitCriterion.java index f136996..bcf3c31 100644 --- a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/splitcriteria/InfoGainSplitCriterion.java +++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/splitcriteria/InfoGainSplitCriterion.java @@ -27,92 +27,91 @@ import com.yahoo.labs.samoa.moa.options.AbstractOptionHandler; import com.yahoo.labs.samoa.moa.tasks.TaskMonitor; /** - * Class for computing splitting criteria using information gain - * with respect to distributions of class values. - * The split criterion is used as a parameter on + * Class for computing splitting criteria using information gain with respect to + * distributions of class values. The split criterion is used as a parameter on * decision trees and decision stumps. - * + * * @author Richard Kirkby ([email protected]) * @version $Revision: 7 $ */ public class InfoGainSplitCriterion extends AbstractOptionHandler implements - SplitCriterion { + SplitCriterion { - private static final long serialVersionUID = 1L; + private static final long serialVersionUID = 1L; - public FloatOption minBranchFracOption = new FloatOption("minBranchFrac", - 'f', - "Minimum fraction of weight required down at least two branches.", - 0.01, 0.0, 0.5); + public FloatOption minBranchFracOption = new FloatOption("minBranchFrac", + 'f', + "Minimum fraction of weight required down at least two branches.", + 0.01, 0.0, 0.5); - @Override - public double getMeritOfSplit(double[] preSplitDist, - double[][] postSplitDists) { - if (numSubsetsGreaterThanFrac(postSplitDists, this.minBranchFracOption.getValue()) < 2) { - return Double.NEGATIVE_INFINITY; - } - return computeEntropy(preSplitDist) - computeEntropy(postSplitDists); + @Override + public double getMeritOfSplit(double[] preSplitDist, + double[][] postSplitDists) { + if (numSubsetsGreaterThanFrac(postSplitDists, this.minBranchFracOption.getValue()) < 2) { + return Double.NEGATIVE_INFINITY; } + return computeEntropy(preSplitDist) - computeEntropy(postSplitDists); + } - @Override - public double getRangeOfMerit(double[] preSplitDist) { - int numClasses = preSplitDist.length > 2 ? preSplitDist.length : 2; - return Utils.log2(numClasses); - } + @Override + public double getRangeOfMerit(double[] preSplitDist) { + int numClasses = preSplitDist.length > 2 ? preSplitDist.length : 2; + return Utils.log2(numClasses); + } - public static double computeEntropy(double[] dist) { - double entropy = 0.0; - double sum = 0.0; - for (double d : dist) { - if (d > 0.0) { // TODO: how small can d be before log2 overflows? - entropy -= d * Utils.log2(d); - sum += d; - } - } - return sum > 0.0 ? (entropy + sum * Utils.log2(sum)) / sum : 0.0; + public static double computeEntropy(double[] dist) { + double entropy = 0.0; + double sum = 0.0; + for (double d : dist) { + if (d > 0.0) { // TODO: how small can d be before log2 overflows? + entropy -= d * Utils.log2(d); + sum += d; + } } + return sum > 0.0 ? (entropy + sum * Utils.log2(sum)) / sum : 0.0; + } - public static double computeEntropy(double[][] dists) { - double totalWeight = 0.0; - double[] distWeights = new double[dists.length]; - for (int i = 0; i < dists.length; i++) { - distWeights[i] = Utils.sum(dists[i]); - totalWeight += distWeights[i]; - } - double entropy = 0.0; - for (int i = 0; i < dists.length; i++) { - entropy += distWeights[i] * computeEntropy(dists[i]); - } - return entropy / totalWeight; + public static double computeEntropy(double[][] dists) { + double totalWeight = 0.0; + double[] distWeights = new double[dists.length]; + for (int i = 0; i < dists.length; i++) { + distWeights[i] = Utils.sum(dists[i]); + totalWeight += distWeights[i]; } - - public static int numSubsetsGreaterThanFrac(double[][] distributions, double minFrac) { - double totalWeight = 0.0; - double[] distSums = new double[distributions.length]; - for (int i = 0; i < distSums.length; i++) { - for (int j = 0; j < distributions[i].length; j++) { - distSums[i] += distributions[i][j]; - } - totalWeight += distSums[i]; - } - int numGreater = 0; - for (double d : distSums) { - double frac = d / totalWeight; - if (frac > minFrac) { - numGreater++; - } - } - return numGreater; + double entropy = 0.0; + for (int i = 0; i < dists.length; i++) { + entropy += distWeights[i] * computeEntropy(dists[i]); } + return entropy / totalWeight; + } - @Override - public void getDescription(StringBuilder sb, int indent) { - // TODO Auto-generated method stub + public static int numSubsetsGreaterThanFrac(double[][] distributions, double minFrac) { + double totalWeight = 0.0; + double[] distSums = new double[distributions.length]; + for (int i = 0; i < distSums.length; i++) { + for (int j = 0; j < distributions[i].length; j++) { + distSums[i] += distributions[i][j]; + } + totalWeight += distSums[i]; } - - @Override - protected void prepareForUseImpl(TaskMonitor monitor, - ObjectRepository repository) { - // TODO Auto-generated method stub + int numGreater = 0; + for (double d : distSums) { + double frac = d / totalWeight; + if (frac > minFrac) { + numGreater++; + } } + return numGreater; + } + + @Override + public void getDescription(StringBuilder sb, int indent) { + // TODO Auto-generated method stub + } + + @Override + protected void prepareForUseImpl(TaskMonitor monitor, + ObjectRepository repository) { + // TODO Auto-generated method stub + } }
http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/23a35dbe/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/splitcriteria/InfoGainSplitCriterionMultilabel.java ---------------------------------------------------------------------- diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/splitcriteria/InfoGainSplitCriterionMultilabel.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/splitcriteria/InfoGainSplitCriterionMultilabel.java index 60e1e1c..c06754a 100644 --- a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/splitcriteria/InfoGainSplitCriterionMultilabel.java +++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/splitcriteria/InfoGainSplitCriterionMultilabel.java @@ -26,29 +26,31 @@ import com.yahoo.labs.samoa.moa.core.Utils; * Class for computing splitting criteria using information gain with respect to * distributions of class values for Multilabel data. The split criterion is * used as a parameter on decision trees and decision stumps. - * + * * @author Richard Kirkby ([email protected]) * @author Jesse Read ([email protected]) * @version $Revision: 1 $ */ public class InfoGainSplitCriterionMultilabel extends InfoGainSplitCriterion { - private static final long serialVersionUID = 1L; + private static final long serialVersionUID = 1L; - public static double computeEntropy(double[] dist) { - double entropy = 0.0; - double sum = 0.0; - for (double d : dist) { - sum += d; - } - if (sum > 0.0) { - for (double num : dist) { - double d = num / sum; - if (d > 0.0) { // TODO: how small can d be before log2 overflows? - entropy -= d * Utils.log2(d) + (1 - d) * Utils.log2(1 - d); //Extension to Multilabel - } - } + public static double computeEntropy(double[] dist) { + double entropy = 0.0; + double sum = 0.0; + for (double d : dist) { + sum += d; + } + if (sum > 0.0) { + for (double num : dist) { + double d = num / sum; + if (d > 0.0) { // TODO: how small can d be before log2 overflows? + entropy -= d * Utils.log2(d) + (1 - d) * Utils.log2(1 - d); // Extension + // to + // Multilabel } - return sum > 0.0 ? entropy : 0.0; + } } + return sum > 0.0 ? entropy : 0.0; + } } http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/23a35dbe/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/splitcriteria/SDRSplitCriterion.java ---------------------------------------------------------------------- diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/splitcriteria/SDRSplitCriterion.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/splitcriteria/SDRSplitCriterion.java index a23c93b..d7173d7 100644 --- a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/splitcriteria/SDRSplitCriterion.java +++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/splitcriteria/SDRSplitCriterion.java @@ -21,13 +21,13 @@ package com.yahoo.labs.samoa.moa.classifiers.core.splitcriteria; */ public class SDRSplitCriterion extends VarianceReductionSplitCriterion { - private static final long serialVersionUID = 1L; + private static final long serialVersionUID = 1L; - public static double computeSD(double[] dist) { - int N = (int)dist[0]; - double sum = dist[1]; - double sumSq = dist[2]; - return Math.sqrt((sumSq - ((sum * sum)/N))/N); - } + public static double computeSD(double[] dist) { + int N = (int) dist[0]; + double sum = dist[1]; + double sumSq = dist[2]; + return Math.sqrt((sumSq - ((sum * sum) / N)) / N); + } } http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/23a35dbe/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/splitcriteria/SplitCriterion.java ---------------------------------------------------------------------- diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/splitcriteria/SplitCriterion.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/splitcriteria/SplitCriterion.java index eba390e..ce5d661 100644 --- a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/splitcriteria/SplitCriterion.java +++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/splitcriteria/SplitCriterion.java @@ -23,34 +23,35 @@ package com.yahoo.labs.samoa.moa.classifiers.core.splitcriteria; import com.yahoo.labs.samoa.moa.options.OptionHandler; /** - * Interface for computing splitting criteria. - * with respect to distributions of class values. - * The split criterion is used as a parameter on - * decision trees and decision stumps. - * The two split criteria most used are - * Information Gain and Gini. - * + * Interface for computing splitting criteria. with respect to distributions of + * class values. The split criterion is used as a parameter on decision trees + * and decision stumps. The two split criteria most used are Information Gain + * and Gini. + * * @author Richard Kirkby ([email protected]) - * @version $Revision: 7 $ + * @version $Revision: 7 $ */ public interface SplitCriterion extends OptionHandler { - /** - * Computes the merit of splitting for a given - * ditribution before the split and after it. - * - * @param preSplitDist the class distribution before the split - * @param postSplitDists the class distribution after the split - * @return value of the merit of splitting - */ - public double getMeritOfSplit(double[] preSplitDist, - double[][] postSplitDists); + /** + * Computes the merit of splitting for a given ditribution before the split + * and after it. + * + * @param preSplitDist + * the class distribution before the split + * @param postSplitDists + * the class distribution after the split + * @return value of the merit of splitting + */ + public double getMeritOfSplit(double[] preSplitDist, + double[][] postSplitDists); - /** - * Computes the range of splitting merit - * - * @param preSplitDist the class distribution before the split - * @return value of the range of splitting merit - */ - public double getRangeOfMerit(double[] preSplitDist); + /** + * Computes the range of splitting merit + * + * @param preSplitDist + * the class distribution before the split + * @return value of the range of splitting merit + */ + public double getRangeOfMerit(double[] preSplitDist); } http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/23a35dbe/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/splitcriteria/VarianceReductionSplitCriterion.java ---------------------------------------------------------------------- diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/splitcriteria/VarianceReductionSplitCriterion.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/splitcriteria/VarianceReductionSplitCriterion.java index c5ca348..6aa66ba 100644 --- a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/splitcriteria/VarianceReductionSplitCriterion.java +++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/splitcriteria/VarianceReductionSplitCriterion.java @@ -26,74 +26,69 @@ import com.yahoo.labs.samoa.moa.tasks.TaskMonitor; public class VarianceReductionSplitCriterion extends AbstractOptionHandler implements SplitCriterion { - private static final long serialVersionUID = 1L; + private static final long serialVersionUID = 1L; -/* @Override - public double getMeritOfSplit(double[] preSplitDist, double[][] postSplitDists) { - - double N = preSplitDist[0]; - double SDR = computeSD(preSplitDist); + /* + * @Override public double getMeritOfSplit(double[] preSplitDist, double[][] + * postSplitDists) { + * + * double N = preSplitDist[0]; double SDR = computeSD(preSplitDist); + * + * // System.out.print("postSplitDists.length"+postSplitDists.length+"\n"); + * for(int i = 0; i < postSplitDists.length; i++) { double Ni = + * postSplitDists[i][0]; SDR -= (Ni/N)*computeSD(postSplitDists[i]); } + * + * return SDR; } + */ - // System.out.print("postSplitDists.length"+postSplitDists.length+"\n"); - for(int i = 0; i < postSplitDists.length; i++) - { - double Ni = postSplitDists[i][0]; - SDR -= (Ni/N)*computeSD(postSplitDists[i]); - } + @Override + public double getMeritOfSplit(double[] preSplitDist, double[][] postSplitDists) { + double SDR = 0.0; + double N = preSplitDist[0]; + int count = 0; - return SDR; - }*/ - - @Override - public double getMeritOfSplit(double[] preSplitDist, double[][] postSplitDists) { - double SDR=0.0; - double N = preSplitDist[0]; - int count = 0; + for (int i1 = 0; i1 < postSplitDists.length; i1++) { + double[] postSplitDist = postSplitDists[i1]; + double Ni = postSplitDist[0]; + if (Ni >= 5.0) { + count = count + 1; + } + } - for (int i1 = 0; i1 < postSplitDists.length; i1++) { - double[] postSplitDist = postSplitDists[i1]; - double Ni = postSplitDist[0]; - if (Ni >= 5.0) { - count = count + 1; - } - } - - if(count == postSplitDists.length){ - SDR = computeSD(preSplitDist); - for(int i = 0; i < postSplitDists.length; i++) - { - double Ni = postSplitDists[i][0]; - SDR -= (Ni/N)*computeSD(postSplitDists[i]); - } - } - return SDR; + if (count == postSplitDists.length) { + SDR = computeSD(preSplitDist); + for (int i = 0; i < postSplitDists.length; i++) + { + double Ni = postSplitDists[i][0]; + SDR -= (Ni / N) * computeSD(postSplitDists[i]); + } } - + return SDR; + } + @Override + public double getRangeOfMerit(double[] preSplitDist) { + return 1; + } - @Override - public double getRangeOfMerit(double[] preSplitDist) { - return 1; - } + public static double computeSD(double[] dist) { - public static double computeSD(double[] dist) { - - int N = (int)dist[0]; - double sum = dist[1]; - double sumSq = dist[2]; - // return Math.sqrt((sumSq - ((sum * sum)/N))/N); - return (sumSq - ((sum * sum)/N))/N; - } + int N = (int) dist[0]; + double sum = dist[1]; + double sumSq = dist[2]; + // return Math.sqrt((sumSq - ((sum * sum)/N))/N); + return (sumSq - ((sum * sum) / N)) / N; + } - @Override - public void getDescription(StringBuilder sb, int indent) { - // TODO Auto-generated method stub - } + @Override + public void getDescription(StringBuilder sb, int indent) { + // TODO Auto-generated method stub + } + + @Override + protected void prepareForUseImpl(TaskMonitor monitor, + ObjectRepository repository) { + // TODO Auto-generated method stub + } - @Override - protected void prepareForUseImpl(TaskMonitor monitor, - ObjectRepository repository) { - // TODO Auto-generated method stub - } - } http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/23a35dbe/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/functions/MajorityClass.java ---------------------------------------------------------------------- diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/functions/MajorityClass.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/functions/MajorityClass.java index 6c2c807..ca3e452 100644 --- a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/functions/MajorityClass.java +++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/functions/MajorityClass.java @@ -28,57 +28,57 @@ import com.yahoo.labs.samoa.instances.Instance; /** * Majority class learner. This is the simplest classifier. - * + * * @author Richard Kirkby ([email protected]) * @version $Revision: 7 $ */ public class MajorityClass extends AbstractClassifier { - private static final long serialVersionUID = 1L; + private static final long serialVersionUID = 1L; - @Override - public String getPurposeString() { - return "Majority class classifier: always predicts the class that has been observed most frequently the in the training data."; - } + @Override + public String getPurposeString() { + return "Majority class classifier: always predicts the class that has been observed most frequently the in the training data."; + } - protected DoubleVector observedClassDistribution; + protected DoubleVector observedClassDistribution; - @Override - public void resetLearningImpl() { - this.observedClassDistribution = new DoubleVector(); - } + @Override + public void resetLearningImpl() { + this.observedClassDistribution = new DoubleVector(); + } - @Override - public void trainOnInstanceImpl(Instance inst) { - this.observedClassDistribution.addToValue((int) inst.classValue(), inst.weight()); - } + @Override + public void trainOnInstanceImpl(Instance inst) { + this.observedClassDistribution.addToValue((int) inst.classValue(), inst.weight()); + } - public double[] getVotesForInstance(Instance i) { - return this.observedClassDistribution.getArrayCopy(); - } + public double[] getVotesForInstance(Instance i) { + return this.observedClassDistribution.getArrayCopy(); + } - @Override - protected Measurement[] getModelMeasurementsImpl() { - return null; - } + @Override + protected Measurement[] getModelMeasurementsImpl() { + return null; + } - @Override - public void getModelDescription(StringBuilder out, int indent) { - StringUtils.appendIndented(out, indent, "Predicted majority "); - out.append(getClassNameString()); - out.append(" = "); - out.append(getClassLabelString(this.observedClassDistribution.maxIndex())); - StringUtils.appendNewline(out); - for (int i = 0; i < this.observedClassDistribution.numValues(); i++) { - StringUtils.appendIndented(out, indent, "Observed weight of "); - out.append(getClassLabelString(i)); - out.append(": "); - out.append(this.observedClassDistribution.getValue(i)); - StringUtils.appendNewline(out); - } + @Override + public void getModelDescription(StringBuilder out, int indent) { + StringUtils.appendIndented(out, indent, "Predicted majority "); + out.append(getClassNameString()); + out.append(" = "); + out.append(getClassLabelString(this.observedClassDistribution.maxIndex())); + StringUtils.appendNewline(out); + for (int i = 0; i < this.observedClassDistribution.numValues(); i++) { + StringUtils.appendIndented(out, indent, "Observed weight of "); + out.append(getClassLabelString(i)); + out.append(": "); + out.append(this.observedClassDistribution.getValue(i)); + StringUtils.appendNewline(out); } + } - public boolean isRandomizable() { - return false; - } + public boolean isRandomizable() { + return false; + } } http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/23a35dbe/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/rules/core/Predicate.java ---------------------------------------------------------------------- diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/rules/core/Predicate.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/rules/core/Predicate.java index d6bdcaa..6ebd3c5 100644 --- a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/rules/core/Predicate.java +++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/rules/core/Predicate.java @@ -24,10 +24,10 @@ import com.yahoo.labs.samoa.instances.Instance; /** * Interface for a predicate (a feature) in rules. - * + * */ public interface Predicate { - - public boolean evaluate(Instance instance); - + + public boolean evaluate(Instance instance); + } http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/23a35dbe/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/rules/core/attributeclassobservers/FIMTDDNumericAttributeClassLimitObserver.java ---------------------------------------------------------------------- diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/rules/core/attributeclassobservers/FIMTDDNumericAttributeClassLimitObserver.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/rules/core/attributeclassobservers/FIMTDDNumericAttributeClassLimitObserver.java index 8b4c332..6633da8 100644 --- a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/rules/core/attributeclassobservers/FIMTDDNumericAttributeClassLimitObserver.java +++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/rules/core/attributeclassobservers/FIMTDDNumericAttributeClassLimitObserver.java @@ -23,104 +23,99 @@ package com.yahoo.labs.samoa.moa.classifiers.rules.core.attributeclassobservers; import com.github.javacliparser.IntOption; import com.yahoo.labs.samoa.moa.classifiers.core.attributeclassobservers.FIMTDDNumericAttributeClassObserver; - public class FIMTDDNumericAttributeClassLimitObserver extends FIMTDDNumericAttributeClassObserver { - /** + /** * */ - private static final long serialVersionUID = 1L; - protected int maxNodes; - //public IntOption maxNodesOption = new IntOption("maxNodesOption", 'z', "Maximum number of nodes", 50, 0, Integer.MAX_VALUE); - - - protected int numNodes; - - public int getMaxNodes() { - return this.maxNodes; - } - - public void setMaxNodes(int maxNodes) { - this.maxNodes = maxNodes; - } - - @Override - public void observeAttributeClass(double attVal, double classVal, double weight) { - if (Double.isNaN(attVal)) { //Instance.isMissingValue(attVal) - } else { - if (this.root == null) { - //maxNodes=maxNodesOption.getValue(); - maxNodes = 50; - this.root = new FIMTDDNumericAttributeClassLimitObserver.Node(attVal, classVal, weight); - } else { - this.root.insertValue(attVal, classVal, weight); - } - } - } - - protected class Node extends FIMTDDNumericAttributeClassObserver.Node { - /** + private static final long serialVersionUID = 1L; + protected int maxNodes; + // public IntOption maxNodesOption = new IntOption("maxNodesOption", 'z', + // "Maximum number of nodes", 50, 0, Integer.MAX_VALUE); + + protected int numNodes; + + public int getMaxNodes() { + return this.maxNodes; + } + + public void setMaxNodes(int maxNodes) { + this.maxNodes = maxNodes; + } + + @Override + public void observeAttributeClass(double attVal, double classVal, double weight) { + if (Double.isNaN(attVal)) { // Instance.isMissingValue(attVal) + } else { + if (this.root == null) { + // maxNodes=maxNodesOption.getValue(); + maxNodes = 50; + this.root = new FIMTDDNumericAttributeClassLimitObserver.Node(attVal, classVal, weight); + } else { + this.root.insertValue(attVal, classVal, weight); + } + } + } + + protected class Node extends FIMTDDNumericAttributeClassObserver.Node { + /** * */ - private static final long serialVersionUID = -4484141636424708465L; - - - - public Node(double val, double label, double weight) { - super(val, label, weight); - } - - protected Node root = null; - - - - /** - * Insert a new value into the tree, updating both the sum of values and - * sum of squared values arrays - */ - @Override - public void insertValue(double val, double label, double weight) { - - // If the new value equals the value stored in a node, update - // the left (<=) node information - if (val == this.cut_point) - { - this.leftStatistics.addToValue(0,1); - this.leftStatistics.addToValue(1,label); - this.leftStatistics.addToValue(2,label*label); - } - // If the new value is less than the value in a node, update the - // left distribution and send the value down to the left child node. - // If no left child exists, create one - else if (val <= this.cut_point) { - this.leftStatistics.addToValue(0,1); - this.leftStatistics.addToValue(1,label); - this.leftStatistics.addToValue(2,label*label); - if (this.left == null) { - if(numNodes<maxNodes){ - this.left = new Node(val, label, weight); - ++numNodes; - } - } else { - this.left.insertValue(val, label, weight); - } - } - // If the new value is greater than the value in a node, update the - // right (>) distribution and send the value down to the right child node. - // If no right child exists, create one - else { // val > cut_point - this.rightStatistics.addToValue(0,1); - this.rightStatistics.addToValue(1,label); - this.rightStatistics.addToValue(2,label*label); - if (this.right == null) { - if(numNodes<maxNodes){ - this.right = new Node(val, label, weight); - ++numNodes; - } - } else { - this.right.insertValue(val, label, weight); - } - } - } - } + private static final long serialVersionUID = -4484141636424708465L; + + public Node(double val, double label, double weight) { + super(val, label, weight); + } + + protected Node root = null; + + /** + * Insert a new value into the tree, updating both the sum of values and sum + * of squared values arrays + */ + @Override + public void insertValue(double val, double label, double weight) { + + // If the new value equals the value stored in a node, update + // the left (<=) node information + if (val == this.cut_point) + { + this.leftStatistics.addToValue(0, 1); + this.leftStatistics.addToValue(1, label); + this.leftStatistics.addToValue(2, label * label); + } + // If the new value is less than the value in a node, update the + // left distribution and send the value down to the left child node. + // If no left child exists, create one + else if (val <= this.cut_point) { + this.leftStatistics.addToValue(0, 1); + this.leftStatistics.addToValue(1, label); + this.leftStatistics.addToValue(2, label * label); + if (this.left == null) { + if (numNodes < maxNodes) { + this.left = new Node(val, label, weight); + ++numNodes; + } + } else { + this.left.insertValue(val, label, weight); + } + } + // If the new value is greater than the value in a node, update the + // right (>) distribution and send the value down to the right child node. + // If no right child exists, create one + else { // val > cut_point + this.rightStatistics.addToValue(0, 1); + this.rightStatistics.addToValue(1, label); + this.rightStatistics.addToValue(2, label * label); + if (this.right == null) { + if (numNodes < maxNodes) { + this.right = new Node(val, label, weight); + ++numNodes; + } + } else { + this.right.insertValue(val, label, weight); + } + } + } + } } \ No newline at end of file http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/23a35dbe/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/rules/core/conditionaltests/NumericAttributeBinaryRulePredicate.java ---------------------------------------------------------------------- diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/rules/core/conditionaltests/NumericAttributeBinaryRulePredicate.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/rules/core/conditionaltests/NumericAttributeBinaryRulePredicate.java index 5f27c40..94edc33 100644 --- a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/rules/core/conditionaltests/NumericAttributeBinaryRulePredicate.java +++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/rules/core/conditionaltests/NumericAttributeBinaryRulePredicate.java @@ -47,134 +47,135 @@ import com.yahoo.labs.samoa.moa.classifiers.rules.core.Predicate; /** * Numeric binary conditional test for instances to use to split nodes in * AMRules. - * + * * @version $Revision: 1 $ */ public class NumericAttributeBinaryRulePredicate extends InstanceConditionalBinaryTest implements Predicate { - private static final long serialVersionUID = 1L; + private static final long serialVersionUID = 1L; - protected int attIndex; + protected int attIndex; - protected double attValue; + protected double attValue; - protected int operator; // 0 =, 1<=, 2> + protected int operator; // 0 =, 1<=, 2> - public NumericAttributeBinaryRulePredicate() { - this(0,0,0); - } - public NumericAttributeBinaryRulePredicate(int attIndex, double attValue, - int operator) { - this.attIndex = attIndex; - this.attValue = attValue; - this.operator = operator; - } - - public NumericAttributeBinaryRulePredicate(NumericAttributeBinaryRulePredicate oldTest) { - this(oldTest.attIndex, oldTest.attValue, oldTest.operator); - } + public NumericAttributeBinaryRulePredicate() { + this(0, 0, 0); + } + + public NumericAttributeBinaryRulePredicate(int attIndex, double attValue, + int operator) { + this.attIndex = attIndex; + this.attValue = attValue; + this.operator = operator; + } - @Override - public int branchForInstance(Instance inst) { - int instAttIndex = this.attIndex < inst.classIndex() ? this.attIndex - : this.attIndex + 1; - if (inst.isMissing(instAttIndex)) { - return -1; - } - double v = inst.value(instAttIndex); - int ret = 0; - switch (this.operator) { - case 0: - ret = (v == this.attValue) ? 0 : 1; - break; - case 1: - ret = (v <= this.attValue) ? 0 : 1; - break; - case 2: - ret = (v > this.attValue) ? 0 : 1; - } - return ret; + public NumericAttributeBinaryRulePredicate(NumericAttributeBinaryRulePredicate oldTest) { + this(oldTest.attIndex, oldTest.attValue, oldTest.operator); + } + + @Override + public int branchForInstance(Instance inst) { + int instAttIndex = this.attIndex < inst.classIndex() ? this.attIndex + : this.attIndex + 1; + if (inst.isMissing(instAttIndex)) { + return -1; + } + double v = inst.value(instAttIndex); + int ret = 0; + switch (this.operator) { + case 0: + ret = (v == this.attValue) ? 0 : 1; + break; + case 1: + ret = (v <= this.attValue) ? 0 : 1; + break; + case 2: + ret = (v > this.attValue) ? 0 : 1; } + return ret; + } - /** + /** * */ - @Override - public String describeConditionForBranch(int branch, InstancesHeader context) { - if ((branch >= 0) && (branch <= 2)) { - String compareChar = (branch == 0) ? "=" : (branch == 1) ? "<=" : ">"; - return InstancesHeader.getAttributeNameString(context, - this.attIndex) - + ' ' - + compareChar - + InstancesHeader.getNumericValueString(context, - this.attIndex, this.attValue); - } - throw new IndexOutOfBoundsException(); + @Override + public String describeConditionForBranch(int branch, InstancesHeader context) { + if ((branch >= 0) && (branch <= 2)) { + String compareChar = (branch == 0) ? "=" : (branch == 1) ? "<=" : ">"; + return InstancesHeader.getAttributeNameString(context, + this.attIndex) + + ' ' + + compareChar + + InstancesHeader.getNumericValueString(context, + this.attIndex, this.attValue); } + throw new IndexOutOfBoundsException(); + } - /** + /** * */ - @Override - public void getDescription(StringBuilder sb, int indent) { - // TODO Auto-generated method stub - } - - @Override - public int[] getAttsTestDependsOn() { - return new int[]{this.attIndex}; - } - - public double getSplitValue() { - return this.attValue; - } - - @Override - public boolean evaluate(Instance inst) { - return (branchForInstance(inst) == 0); + @Override + public void getDescription(StringBuilder sb, int indent) { + // TODO Auto-generated method stub + } + + @Override + public int[] getAttsTestDependsOn() { + return new int[] { this.attIndex }; + } + + public double getSplitValue() { + return this.attValue; + } + + @Override + public boolean evaluate(Instance inst) { + return (branchForInstance(inst) == 0); + } + + @Override + public String toString() { + if ((operator >= 0) && (operator <= 2)) { + String compareChar = (operator == 0) ? "=" : (operator == 1) ? "<=" : ">"; + // int equalsBranch = this.equalsPassesTest ? 0 : 1; + return "x" + this.attIndex + + ' ' + + compareChar + + ' ' + + this.attValue; } - - @Override - public String toString() { - if ((operator >= 0) && (operator <= 2)) { - String compareChar = (operator == 0) ? "=" : (operator == 1) ? "<=" : ">"; - //int equalsBranch = this.equalsPassesTest ? 0 : 1; - return "x" + this.attIndex - + ' ' - + compareChar - + ' ' - + this.attValue; - } - throw new IndexOutOfBoundsException(); + throw new IndexOutOfBoundsException(); + } + + public boolean isEqual(NumericAttributeBinaryRulePredicate predicate) { + return (this.attIndex == predicate.attIndex + && this.attValue == predicate.attValue + && this.operator == predicate.operator); + } + + public boolean isUsingSameAttribute(NumericAttributeBinaryRulePredicate predicate) { + return (this.attIndex == predicate.attIndex + && this.operator == predicate.operator); + } + + public boolean isIncludedInRuleNode( + NumericAttributeBinaryRulePredicate predicate) { + boolean ret; + if (this.operator == 1) { // <= + ret = (predicate.attValue <= this.attValue); + } else { // > + ret = (predicate.attValue > this.attValue); } - public boolean isEqual(NumericAttributeBinaryRulePredicate predicate) { - return (this.attIndex == predicate.attIndex - && this.attValue == predicate.attValue - && this.operator == predicate.operator); - } + return ret; + } - public boolean isUsingSameAttribute(NumericAttributeBinaryRulePredicate predicate) { - return (this.attIndex == predicate.attIndex - && this.operator == predicate.operator); - } + public void setAttributeValue( + NumericAttributeBinaryRulePredicate ruleSplitNodeTest) { + this.attValue = ruleSplitNodeTest.attValue; - public boolean isIncludedInRuleNode( - NumericAttributeBinaryRulePredicate predicate) { - boolean ret; - if (this.operator == 1) { // <= - ret = (predicate.attValue <= this.attValue); - } else { // > - ret = (predicate.attValue > this.attValue); - } - - return ret; - } - - public void setAttributeValue( - NumericAttributeBinaryRulePredicate ruleSplitNodeTest) { - this.attValue = ruleSplitNodeTest.attValue; - - } + } } http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/23a35dbe/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/rules/core/splitcriteria/SDRSplitCriterionAMRules.java ---------------------------------------------------------------------- diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/rules/core/splitcriteria/SDRSplitCriterionAMRules.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/rules/core/splitcriteria/SDRSplitCriterionAMRules.java index 58d1e2f..818d075 100644 --- a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/rules/core/splitcriteria/SDRSplitCriterionAMRules.java +++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/rules/core/splitcriteria/SDRSplitCriterionAMRules.java @@ -43,63 +43,57 @@ package com.yahoo.labs.samoa.moa.classifiers.rules.core.splitcriteria; import com.yahoo.labs.samoa.moa.classifiers.core.splitcriteria.SDRSplitCriterion; import com.yahoo.labs.samoa.moa.classifiers.core.splitcriteria.SplitCriterion; - public class SDRSplitCriterionAMRules extends SDRSplitCriterion implements SplitCriterion { - private static final long serialVersionUID = 1L; - - - @Override - public double getMeritOfSplit(double[] preSplitDist, double[][] postSplitDists) { - double SDR=0.0; - double N = preSplitDist[0]; - int count = 0; - - for(int i = 0; i < postSplitDists.length; i++) - { - double Ni = postSplitDists[i][0]; - if(Ni >=0.05*preSplitDist[0]){ - count = count +1; - } - } - if(count == postSplitDists.length){ - SDR = computeSD(preSplitDist); - - for(int i = 0; i < postSplitDists.length; i++) - { - double Ni = postSplitDists[i][0]; - SDR -= (Ni/N)*computeSD(postSplitDists[i]); - - } - } - return SDR; - } - - - - @Override - public double getRangeOfMerit(double[] preSplitDist) { - return 1; - } - - public static double[] computeBranchSplitMerits(double[][] postSplitDists) { - double[] SDR = new double[postSplitDists.length]; - double N = 0; - - for(int i = 0; i < postSplitDists.length; i++) - { - double Ni = postSplitDists[i][0]; - N += Ni; - } - for(int i = 0; i < postSplitDists.length; i++) - { - double Ni = postSplitDists[i][0]; - SDR[i] = (Ni/N)*computeSD(postSplitDists[i]); - } - return SDR; - - } - + private static final long serialVersionUID = 1L; + + @Override + public double getMeritOfSplit(double[] preSplitDist, double[][] postSplitDists) { + double SDR = 0.0; + double N = preSplitDist[0]; + int count = 0; + + for (int i = 0; i < postSplitDists.length; i++) + { + double Ni = postSplitDists[i][0]; + if (Ni >= 0.05 * preSplitDist[0]) { + count = count + 1; + } + } + if (count == postSplitDists.length) { + SDR = computeSD(preSplitDist); + + for (int i = 0; i < postSplitDists.length; i++) + { + double Ni = postSplitDists[i][0]; + SDR -= (Ni / N) * computeSD(postSplitDists[i]); + + } + } + return SDR; + } + + @Override + public double getRangeOfMerit(double[] preSplitDist) { + return 1; + } + + public static double[] computeBranchSplitMerits(double[][] postSplitDists) { + double[] SDR = new double[postSplitDists.length]; + double N = 0; + + for (int i = 0; i < postSplitDists.length; i++) + { + double Ni = postSplitDists[i][0]; + N += Ni; + } + for (int i = 0; i < postSplitDists.length; i++) + { + double Ni = postSplitDists[i][0]; + SDR[i] = (Ni / N) * computeSD(postSplitDists[i]); + } + return SDR; + + } } - http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/23a35dbe/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/rules/core/voting/AbstractErrorWeightedVote.java ---------------------------------------------------------------------- diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/rules/core/voting/AbstractErrorWeightedVote.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/rules/core/voting/AbstractErrorWeightedVote.java index 4e93b03..dcd975d 100644 --- a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/rules/core/voting/AbstractErrorWeightedVote.java +++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/rules/core/voting/AbstractErrorWeightedVote.java @@ -26,78 +26,77 @@ import java.util.List; import com.yahoo.labs.samoa.moa.AbstractMOAObject; /** - * AbstractErrorWeightedVote class for weighted votes based on estimates of errors. - * + * AbstractErrorWeightedVote class for weighted votes based on estimates of + * errors. + * * @author Joao Duarte ([email protected]) * @version $Revision: 1 $ */ -public abstract class AbstractErrorWeightedVote extends AbstractMOAObject implements ErrorWeightedVote{ - /** +public abstract class AbstractErrorWeightedVote extends AbstractMOAObject implements ErrorWeightedVote { + /** * */ - private static final long serialVersionUID = -7340491298217227675L; - protected List<double[]> votes; - protected List<Double> errors; - protected double[] weights; - - + private static final long serialVersionUID = -7340491298217227675L; + protected List<double[]> votes; + protected List<Double> errors; + protected double[] weights; - public AbstractErrorWeightedVote() { - super(); - votes = new ArrayList<double[]>(); - errors = new ArrayList<Double>(); - } - - public AbstractErrorWeightedVote(AbstractErrorWeightedVote aewv) { - super(); - votes = new ArrayList<double[]>(); - for (double[] vote:aewv.votes) { - double[] v = new double[vote.length]; - for (int i=0; i<vote.length; i++) v[i] = vote[i]; - votes.add(v); - } - errors = new ArrayList<Double>(); - for (Double db:aewv.errors) { - errors.add(db.doubleValue()); - } - if (aewv.weights != null) { - weights = new double[aewv.weights.length]; - for (int i = 0; i<aewv.weights.length; i++) - weights[i] = aewv.weights[i]; - } - } + public AbstractErrorWeightedVote() { + super(); + votes = new ArrayList<double[]>(); + errors = new ArrayList<Double>(); + } + public AbstractErrorWeightedVote(AbstractErrorWeightedVote aewv) { + super(); + votes = new ArrayList<double[]>(); + for (double[] vote : aewv.votes) { + double[] v = new double[vote.length]; + for (int i = 0; i < vote.length; i++) + v[i] = vote[i]; + votes.add(v); + } + errors = new ArrayList<Double>(); + for (Double db : aewv.errors) { + errors.add(db.doubleValue()); + } + if (aewv.weights != null) { + weights = new double[aewv.weights.length]; + for (int i = 0; i < aewv.weights.length; i++) + weights[i] = aewv.weights[i]; + } + } - @Override - public void addVote(double [] vote, double error) { - votes.add(vote); - errors.add(error); - } + @Override + public void addVote(double[] vote, double error) { + votes.add(vote); + errors.add(error); + } - @Override - abstract public double[] computeWeightedVote(); + @Override + abstract public double[] computeWeightedVote(); - @Override - public double getWeightedError() - { - double weightedError=0; - if (weights!=null && weights.length==errors.size()) - { - for (int i=0; i<weights.length; ++i) - weightedError+=errors.get(i)*weights[i]; - } - else - weightedError=-1; - return weightedError; - } + @Override + public double getWeightedError() + { + double weightedError = 0; + if (weights != null && weights.length == errors.size()) + { + for (int i = 0; i < weights.length; ++i) + weightedError += errors.get(i) * weights[i]; + } + else + weightedError = -1; + return weightedError; + } - @Override - public double [] getWeights() { - return weights; - } + @Override + public double[] getWeights() { + return weights; + } - @Override - public int getNumberVotes() { - return votes.size(); - } + @Override + public int getNumberVotes() { + return votes.size(); + } } http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/23a35dbe/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/rules/core/voting/ErrorWeightedVote.java ---------------------------------------------------------------------- diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/rules/core/voting/ErrorWeightedVote.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/rules/core/voting/ErrorWeightedVote.java index 943dd9d..fca7115 100644 --- a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/rules/core/voting/ErrorWeightedVote.java +++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/rules/core/voting/ErrorWeightedVote.java @@ -23,59 +23,60 @@ package com.yahoo.labs.samoa.moa.classifiers.rules.core.voting; import com.yahoo.labs.samoa.moa.MOAObject; /** - * ErrorWeightedVote interface for weighted votes based on estimates of errors. - * + * ErrorWeightedVote interface for weighted votes based on estimates of errors. + * * @author Joao Duarte ([email protected]) * @version $Revision: 1 $ */ public interface ErrorWeightedVote { - - /** - * Adds a vote and the corresponding error for the computation of the weighted vote and respective weighted error. - * - * @param vote a vote returned by a classifier - * @param error the error associated to the vote - */ - public void addVote(double [] vote, double error); - - /** - * Computes the weighted vote. - * Also updates the weights of the votes. - * - * @return the weighted vote - */ - public double [] computeWeightedVote(); - - /** - * Returns the weighted error. - * - * @pre computeWeightedVote() - * @return the weighted error - */ - public double getWeightedError(); - - /** - * Return the weights error. - * - * @pre computeWeightedVote() - * @return the weights - */ - public double [] getWeights(); - - - /** - * The number of votes added so far. - * - * @return the number of votes - */ - public int getNumberVotes(); - - /** - * Creates a copy of the object - * - * @return copy of the object - */ - public MOAObject copy(); - - public ErrorWeightedVote getACopy(); + + /** + * Adds a vote and the corresponding error for the computation of the weighted + * vote and respective weighted error. + * + * @param vote + * a vote returned by a classifier + * @param error + * the error associated to the vote + */ + public void addVote(double[] vote, double error); + + /** + * Computes the weighted vote. Also updates the weights of the votes. + * + * @return the weighted vote + */ + public double[] computeWeightedVote(); + + /** + * Returns the weighted error. + * + * @pre computeWeightedVote() + * @return the weighted error + */ + public double getWeightedError(); + + /** + * Return the weights error. + * + * @pre computeWeightedVote() + * @return the weights + */ + public double[] getWeights(); + + /** + * The number of votes added so far. + * + * @return the number of votes + */ + public int getNumberVotes(); + + /** + * Creates a copy of the object + * + * @return copy of the object + */ + public MOAObject copy(); + + public ErrorWeightedVote getACopy(); } http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/23a35dbe/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/rules/core/voting/InverseErrorWeightedVote.java ---------------------------------------------------------------------- diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/rules/core/voting/InverseErrorWeightedVote.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/rules/core/voting/InverseErrorWeightedVote.java index 401dc58..2e2cb56 100644 --- a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/rules/core/voting/InverseErrorWeightedVote.java +++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/rules/core/voting/InverseErrorWeightedVote.java @@ -21,79 +21,81 @@ package com.yahoo.labs.samoa.moa.classifiers.rules.core.voting; */ /** - * InverseErrorWeightedVote class for weighted votes based on estimates of errors. - * + * InverseErrorWeightedVote class for weighted votes based on estimates of + * errors. + * * @author Joao Duarte ([email protected]) * @version $Revision: 1 $ */ public class InverseErrorWeightedVote extends AbstractErrorWeightedVote { - /** + /** * */ - private static final double EPS = 0.000000001; //just to prevent divide by 0 in 1/X -> 1/(x+EPS) - private static final long serialVersionUID = 6359349250620616482L; + private static final double EPS = 0.000000001; // just to prevent divide by 0 + // in 1/X -> 1/(x+EPS) + private static final long serialVersionUID = 6359349250620616482L; + + public InverseErrorWeightedVote() { + super(); + } + + public InverseErrorWeightedVote(AbstractErrorWeightedVote aewv) { + super(aewv); + } + + @Override + public double[] computeWeightedVote() { + int n = votes.size(); + weights = new double[n]; + double[] weightedVote = null; + if (n > 0) { + int d = votes.get(0).length; + weightedVote = new double[d]; + double sumError = 0; + // weights are 1/(error+eps) + for (int i = 0; i < n; ++i) { + if (errors.get(i) < Double.MAX_VALUE) { + weights[i] = 1.0 / (errors.get(i) + EPS); + sumError += weights[i]; + } + else + weights[i] = 0; - public InverseErrorWeightedVote() { - super(); - } - - public InverseErrorWeightedVote(AbstractErrorWeightedVote aewv) { - super(aewv); - } - - @Override - public double[] computeWeightedVote() { - int n=votes.size(); - weights=new double[n]; - double [] weightedVote=null; - if (n>0){ - int d=votes.get(0).length; - weightedVote=new double[d]; - double sumError=0; - //weights are 1/(error+eps) - for (int i=0; i<n; ++i){ - if(errors.get(i)<Double.MAX_VALUE){ - weights[i]=1.0/(errors.get(i)+EPS); - sumError+=weights[i]; - } - else - weights[i]=0; + } - } + if (sumError > 0) + for (int i = 0; i < n; ++i) + { + // normalize so that weights sum 1 + weights[i] /= sumError; + // compute weighted vote + for (int j = 0; j < d; j++) + weightedVote[j] += votes.get(i)[j] * weights[i]; + } + // Only occurs if all errors=Double.MAX_VALUE + else + { + // compute arithmetic vote + for (int i = 0; i < n; ++i) + { + for (int j = 0; j < d; j++) + weightedVote[j] += votes.get(i)[j] / n; + } + } + } + return weightedVote; + } - if(sumError>0) - for (int i=0; i<n; ++i) - { - //normalize so that weights sum 1 - weights[i]/=sumError; - //compute weighted vote - for(int j=0; j<d; j++) - weightedVote[j]+=votes.get(i)[j]*weights[i]; - } - //Only occurs if all errors=Double.MAX_VALUE - else - { - //compute arithmetic vote - for (int i=0; i<n; ++i) - { - for(int j=0; j<d; j++) - weightedVote[j]+=votes.get(i)[j]/n; - } - } - } - return weightedVote; - } + @Override + public void getDescription(StringBuilder sb, int indent) { + // TODO Auto-generated method stub - @Override - public void getDescription(StringBuilder sb, int indent) { - // TODO Auto-generated method stub + } - } - - @Override - public InverseErrorWeightedVote getACopy() { - return new InverseErrorWeightedVote(this); - } + @Override + public InverseErrorWeightedVote getACopy() { + return new InverseErrorWeightedVote(this); + } } http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/23a35dbe/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/rules/core/voting/UniformWeightedVote.java ---------------------------------------------------------------------- diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/rules/core/voting/UniformWeightedVote.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/rules/core/voting/UniformWeightedVote.java index ce7a74f..61c5b12 100644 --- a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/rules/core/voting/UniformWeightedVote.java +++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/rules/core/voting/UniformWeightedVote.java @@ -20,54 +20,52 @@ package com.yahoo.labs.samoa.moa.classifiers.rules.core.voting; * #L% */ - /** - * UniformWeightedVote class for weighted votes based on estimates of errors. - * + * UniformWeightedVote class for weighted votes based on estimates of errors. + * * @author Joao Duarte ([email protected]) * @version $Revision: 1 $ */ public class UniformWeightedVote extends AbstractErrorWeightedVote { + private static final long serialVersionUID = 6359349250620616482L; + + public UniformWeightedVote() { + super(); + } + + public UniformWeightedVote(AbstractErrorWeightedVote aewv) { + super(aewv); + } - private static final long serialVersionUID = 6359349250620616482L; + @Override + public double[] computeWeightedVote() { + int n = votes.size(); + weights = new double[n]; + double[] weightedVote = null; + if (n > 0) { + int d = votes.get(0).length; + weightedVote = new double[d]; + for (int i = 0; i < n; i++) + { + weights[i] = 1.0 / n; + for (int j = 0; j < d; j++) + weightedVote[j] += (votes.get(i)[j] * weights[i]); + } - public UniformWeightedVote() { - super(); - } - - public UniformWeightedVote(AbstractErrorWeightedVote aewv) { - super(aewv); - } - - @Override - public double[] computeWeightedVote() { - int n=votes.size(); - weights=new double[n]; - double [] weightedVote=null; - if (n>0){ - int d=votes.get(0).length; - weightedVote=new double[d]; - for (int i=0; i<n; i++) - { - weights[i]=1.0/n; - for(int j=0; j<d; j++) - weightedVote[j]+=(votes.get(i)[j]*weights[i]); - } + } + return weightedVote; + } - } - return weightedVote; - } + @Override + public void getDescription(StringBuilder sb, int indent) { + // TODO Auto-generated method stub - @Override - public void getDescription(StringBuilder sb, int indent) { - // TODO Auto-generated method stub + } - } - - @Override - public UniformWeightedVote getACopy() { - return new UniformWeightedVote(this); - } + @Override + public UniformWeightedVote getACopy() { + return new UniformWeightedVote(this); + } } http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/23a35dbe/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/rules/driftdetection/PageHinkleyFading.java ---------------------------------------------------------------------- diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/rules/driftdetection/PageHinkleyFading.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/rules/driftdetection/PageHinkleyFading.java index 133c755..d18134b 100644 --- a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/rules/driftdetection/PageHinkleyFading.java +++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/rules/driftdetection/PageHinkleyFading.java @@ -22,62 +22,67 @@ package com.yahoo.labs.samoa.moa.classifiers.rules.driftdetection; /** * Page-Hinkley Test with more weight for recent instances. - * + * */ public class PageHinkleyFading extends PageHinkleyTest { - /** + /** * */ - private static final long serialVersionUID = 7110953184708812339L; - private double fadingFactor=0.99; + private static final long serialVersionUID = 7110953184708812339L; + private double fadingFactor = 0.99; + + public PageHinkleyFading() { + super(); + } + + public PageHinkleyFading(double threshold, double alpha) { + super(threshold, alpha); + } + + protected double instancesSeen; + + @Override + public void reset() { - public PageHinkleyFading() { - super(); - } - - public PageHinkleyFading(double threshold, double alpha) { - super(threshold, alpha); - } - protected double instancesSeen; + super.reset(); + this.instancesSeen = 0; - @Override - public void reset() { + } - super.reset(); - this.instancesSeen=0; + @Override + public boolean update(double error) { + this.instancesSeen = 1 + fadingFactor * this.instancesSeen; + double absolutError = Math.abs(error); - } + this.sumAbsolutError = fadingFactor * this.sumAbsolutError + absolutError; + if (this.instancesSeen > 30) { + double mT = absolutError - (this.sumAbsolutError / this.instancesSeen) - this.alpha; + this.cumulativeSum = this.cumulativeSum + mT; // Update the cumulative mT + // sum + if (this.cumulativeSum < this.minimumValue) { // Update the minimum mT + // value if the new mT is + // smaller than the current + // minimum + this.minimumValue = this.cumulativeSum; + } + return (((this.cumulativeSum - this.minimumValue) > this.threshold)); + } + return false; + } - @Override - public boolean update(double error) { - this.instancesSeen=1+fadingFactor*this.instancesSeen; - double absolutError = Math.abs(error); + @Override + public PageHinkleyTest getACopy() { + PageHinkleyFading newTest = new PageHinkleyFading(this.threshold, this.alpha); + this.copyFields(newTest); + return newTest; + } - this.sumAbsolutError = fadingFactor*this.sumAbsolutError + absolutError; - if (this.instancesSeen > 30) { - double mT = absolutError - (this.sumAbsolutError / this.instancesSeen) - this.alpha; - this.cumulativeSum = this.cumulativeSum + mT; // Update the cumulative mT sum - if (this.cumulativeSum < this.minimumValue) { // Update the minimum mT value if the new mT is smaller than the current minimum - this.minimumValue = this.cumulativeSum; - } - return (((this.cumulativeSum - this.minimumValue) > this.threshold)); - } - return false; - } - - @Override - public PageHinkleyTest getACopy() { - PageHinkleyFading newTest = new PageHinkleyFading(this.threshold, this.alpha); - this.copyFields(newTest); - return newTest; - } - - @Override - protected void copyFields(PageHinkleyTest newTest) { - super.copyFields(newTest); - PageHinkleyFading newFading = (PageHinkleyFading) newTest; - newFading.fadingFactor = this.fadingFactor; - newFading.instancesSeen = this.instancesSeen; - } + @Override + protected void copyFields(PageHinkleyTest newTest) { + super.copyFields(newTest); + PageHinkleyFading newFading = (PageHinkleyFading) newTest; + newFading.fadingFactor = this.fadingFactor; + newFading.instancesSeen = this.instancesSeen; + } } \ No newline at end of file http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/23a35dbe/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/rules/driftdetection/PageHinkleyTest.java ---------------------------------------------------------------------- diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/rules/driftdetection/PageHinkleyTest.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/rules/driftdetection/PageHinkleyTest.java index c313224..354b9a8 100644 --- a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/rules/driftdetection/PageHinkleyTest.java +++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/rules/driftdetection/PageHinkleyTest.java @@ -24,73 +24,76 @@ import java.io.Serializable; /** * Page-Hinkley Test with equal weights for all instances. - * + * */ public class PageHinkleyTest implements Serializable { - private static final long serialVersionUID = 1L; - protected double cumulativeSum; + private static final long serialVersionUID = 1L; + protected double cumulativeSum; + + public double getCumulativeSum() { + return cumulativeSum; + } + + public double getMinimumValue() { + return minimumValue; + } - public double getCumulativeSum() { - return cumulativeSum; - } + protected double minimumValue; + protected double sumAbsolutError; + protected long phinstancesSeen; + protected double threshold; + protected double alpha; - public double getMinimumValue() { - return minimumValue; - } + public PageHinkleyTest() { + this(0, 0); + } + public PageHinkleyTest(double threshold, double alpha) { + this.threshold = threshold; + this.alpha = alpha; + this.reset(); + } - protected double minimumValue; - protected double sumAbsolutError; - protected long phinstancesSeen; - protected double threshold; - protected double alpha; + public void reset() { + this.cumulativeSum = 0.0; + this.minimumValue = Double.MAX_VALUE; + this.sumAbsolutError = 0.0; + this.phinstancesSeen = 0; + } - public PageHinkleyTest() { - this(0,0); - } - - public PageHinkleyTest(double threshold, double alpha) { - this.threshold = threshold; - this.alpha = alpha; - this.reset(); - } + // Compute Page-Hinkley test + public boolean update(double error) { - public void reset() { - this.cumulativeSum = 0.0; - this.minimumValue = Double.MAX_VALUE; - this.sumAbsolutError = 0.0; - this.phinstancesSeen = 0; - } + this.phinstancesSeen++; + double absolutError = Math.abs(error); + this.sumAbsolutError = this.sumAbsolutError + absolutError; + if (this.phinstancesSeen > 30) { + double mT = absolutError - (this.sumAbsolutError / this.phinstancesSeen) - this.alpha; + this.cumulativeSum = this.cumulativeSum + mT; // Update the cumulative mT + // sum + if (this.cumulativeSum < this.minimumValue) { // Update the minimum mT + // value if the new mT is + // smaller than the current + // minimum + this.minimumValue = this.cumulativeSum; + } + return (((this.cumulativeSum - this.minimumValue) > this.threshold)); + } + return false; + } - //Compute Page-Hinkley test - public boolean update(double error) { + public PageHinkleyTest getACopy() { + PageHinkleyTest newTest = new PageHinkleyTest(this.threshold, this.alpha); + this.copyFields(newTest); + return newTest; + } - this.phinstancesSeen++; - double absolutError = Math.abs(error); - this.sumAbsolutError = this.sumAbsolutError + absolutError; - if (this.phinstancesSeen > 30) { - double mT = absolutError - (this.sumAbsolutError / this.phinstancesSeen) - this.alpha; - this.cumulativeSum = this.cumulativeSum + mT; // Update the cumulative mT sum - if (this.cumulativeSum < this.minimumValue) { // Update the minimum mT value if the new mT is smaller than the current minimum - this.minimumValue = this.cumulativeSum; - } - return (((this.cumulativeSum - this.minimumValue) > this.threshold)); - } - return false; - } - - public PageHinkleyTest getACopy() { - PageHinkleyTest newTest = new PageHinkleyTest(this.threshold, this.alpha); - this.copyFields(newTest); - return newTest; - } - - protected void copyFields(PageHinkleyTest newTest) { - newTest.cumulativeSum = this.cumulativeSum; - newTest.minimumValue = this.minimumValue; - newTest.sumAbsolutError = this.sumAbsolutError; - newTest.phinstancesSeen = this.phinstancesSeen; - } + protected void copyFields(PageHinkleyTest newTest) { + newTest.cumulativeSum = this.cumulativeSum; + newTest.minimumValue = this.minimumValue; + newTest.sumAbsolutError = this.sumAbsolutError; + newTest.phinstancesSeen = this.phinstancesSeen; + } } \ No newline at end of file http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/23a35dbe/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/cluster/CFCluster.java ---------------------------------------------------------------------- diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/cluster/CFCluster.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/cluster/CFCluster.java index 47e1379..b8f22c3 100644 --- a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/cluster/CFCluster.java +++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/cluster/CFCluster.java @@ -1,4 +1,3 @@ - package com.yahoo.labs.samoa.moa.cluster; /* @@ -26,148 +25,153 @@ import com.yahoo.labs.samoa.instances.Instance; /* micro cluster, as defined by Aggarwal et al, On Clustering Massive Data Streams: A Summarization Praradigm * in the book Data streams : models and algorithms, by Charu C Aggarwal * @article{ - title = {Data Streams: Models and Algorithms}, - author = {Aggarwal, Charu C.}, - year = {2007}, - publisher = {Springer Science+Business Media, LLC}, - url = {http://ebooks.ulb.tu-darmstadt.de/11157/}, - institution = {eBooks [http://ebooks.ulb.tu-darmstadt.de/perl/oai2] (Germany)}, -} + title = {Data Streams: Models and Algorithms}, + author = {Aggarwal, Charu C.}, + year = {2007}, + publisher = {Springer Science+Business Media, LLC}, + url = {http://ebooks.ulb.tu-darmstadt.de/11157/}, + institution = {eBooks [http://ebooks.ulb.tu-darmstadt.de/perl/oai2] (Germany)}, + } -DEFINITION A micro-clusterfor a set of d-dimensionalpoints Xi,. .Xi, -with t i m e s t a m p s ~. . .T,, is the (2-d+3)tuple (CF2", CFlX CF2t, CFlt, n), -wherein CF2" and CFlX each correspond to a vector of d entries. The definition of each of these entries is as follows: + DEFINITION A micro-clusterfor a set of d-dimensionalpoints Xi,. .Xi, + with t i m e s t a m p s ~. . .T,, is the (2-d+3)tuple (CF2", CFlX CF2t, CFlt, n), + wherein CF2" and CFlX each correspond to a vector of d entries. The definition of each of these entries is as follows: -o For each dimension, the sum of the squares of the data values is maintained -in CF2". Thus, CF2" contains d values. The p-th entry of CF2" is equal to -\sum_j=1^n(x_i_j)^2 + o For each dimension, the sum of the squares of the data values is maintained + in CF2". Thus, CF2" contains d values. The p-th entry of CF2" is equal to + \sum_j=1^n(x_i_j)^2 -o For each dimension, the sum of the data values is maintained in C F l X . -Thus, CFIX contains d values. The p-th entry of CFIX is equal to -\sum_j=1^n x_i_j + o For each dimension, the sum of the data values is maintained in C F l X . + Thus, CFIX contains d values. The p-th entry of CFIX is equal to + \sum_j=1^n x_i_j -o The sum of the squares of the time stamps Ti,. .Tin maintained in CF2t + o The sum of the squares of the time stamps Ti,. .Tin maintained in CF2t -o The sum of the time stamps Ti, . . .Tin maintained in CFlt. + o The sum of the time stamps Ti, . . .Tin maintained in CFlt. -o The number of data points is maintained in n. + o The number of data points is maintained in n. */ public abstract class CFCluster extends SphereCluster { - private static final long serialVersionUID = 1L; - - protected double radiusFactor = 1.8; - - /** - * Number of points in the cluster. - */ - protected double N; - /** - * Linear sum of all the points added to the cluster. - */ - public double[] LS; - /** - * Squared sum of all the points added to the cluster. - */ - public double[] SS; - - /** - * Instantiates an empty kernel with the given dimensionality. - * @param dimensions The number of dimensions of the points that can be in - * this kernel. - */ - public CFCluster(Instance instance, int dimensions) { - this(instance.toDoubleArray(), dimensions); - } - - protected CFCluster(int dimensions) { - this.N = 0; - this.LS = new double[dimensions]; - this.SS = new double[dimensions]; - Arrays.fill(this.LS, 0.0); - Arrays.fill(this.SS, 0.0); - } - - public CFCluster(double [] center, int dimensions) { - this.N = 1; - this.LS = center; - this.SS = new double[dimensions]; - for (int i = 0; i < SS.length; i++) { - SS[i]=Math.pow(center[i], 2); - } - } - - public CFCluster(CFCluster cluster) { - this.N = cluster.N; - this.LS = Arrays.copyOf(cluster.LS, cluster.LS.length); - this.SS = Arrays.copyOf(cluster.SS, cluster.SS.length); - } - - public void add(CFCluster cluster ) { - this.N += cluster.N; - addVectors( this.LS, cluster.LS ); - addVectors( this.SS, cluster.SS ); - } - - public abstract CFCluster getCF(); - - /** - * @return this kernels' center - */ - @Override - public double[] getCenter() { - assert (this.N>0); - double res[] = new double[this.LS.length]; - for ( int i = 0; i < res.length; i++ ) { - res[i] = this.LS[i] / N; - } - return res; - } - - - @Override - public abstract double getInclusionProbability(Instance instance); - - /** - * See interface <code>Cluster</code> - * @return The radius of the cluster. - */ - @Override - public abstract double getRadius(); - - /** - * See interface <code>Cluster</code> - * @return The weight. - * @see Cluster#getWeight() - */ - @Override - public double getWeight() { - return N; - } - - public void setN(double N){ - this.N = N; - } - - public double getN() { - return N; - } - - /** - * Adds the second array to the first array element by element. The arrays - * must have the same length. - * @param a1 Vector to which the second vector is added. - * @param a2 Vector to be added. This vector does not change. - */ - public static void addVectors(double[] a1, double[] a2) { - assert (a1 != null); - assert (a2 != null); - assert (a1.length == a2.length) : "Adding two arrays of different " - + "length"; - - for (int i = 0; i < a1.length; i++) { - a1[i] += a2[i]; - } - } + private static final long serialVersionUID = 1L; + + protected double radiusFactor = 1.8; + + /** + * Number of points in the cluster. + */ + protected double N; + /** + * Linear sum of all the points added to the cluster. + */ + public double[] LS; + /** + * Squared sum of all the points added to the cluster. + */ + public double[] SS; + + /** + * Instantiates an empty kernel with the given dimensionality. + * + * @param dimensions + * The number of dimensions of the points that can be in this kernel. + */ + public CFCluster(Instance instance, int dimensions) { + this(instance.toDoubleArray(), dimensions); + } + + protected CFCluster(int dimensions) { + this.N = 0; + this.LS = new double[dimensions]; + this.SS = new double[dimensions]; + Arrays.fill(this.LS, 0.0); + Arrays.fill(this.SS, 0.0); + } + + public CFCluster(double[] center, int dimensions) { + this.N = 1; + this.LS = center; + this.SS = new double[dimensions]; + for (int i = 0; i < SS.length; i++) { + SS[i] = Math.pow(center[i], 2); + } + } + + public CFCluster(CFCluster cluster) { + this.N = cluster.N; + this.LS = Arrays.copyOf(cluster.LS, cluster.LS.length); + this.SS = Arrays.copyOf(cluster.SS, cluster.SS.length); + } + + public void add(CFCluster cluster) { + this.N += cluster.N; + addVectors(this.LS, cluster.LS); + addVectors(this.SS, cluster.SS); + } + + public abstract CFCluster getCF(); + + /** + * @return this kernels' center + */ + @Override + public double[] getCenter() { + assert (this.N > 0); + double res[] = new double[this.LS.length]; + for (int i = 0; i < res.length; i++) { + res[i] = this.LS[i] / N; + } + return res; + } + + @Override + public abstract double getInclusionProbability(Instance instance); + + /** + * See interface <code>Cluster</code> + * + * @return The radius of the cluster. + */ + @Override + public abstract double getRadius(); + + /** + * See interface <code>Cluster</code> + * + * @return The weight. + * @see Cluster#getWeight() + */ + @Override + public double getWeight() { + return N; + } + + public void setN(double N) { + this.N = N; + } + + public double getN() { + return N; + } + + /** + * Adds the second array to the first array element by element. The arrays + * must have the same length. + * + * @param a1 + * Vector to which the second vector is added. + * @param a2 + * Vector to be added. This vector does not change. + */ + public static void addVectors(double[] a1, double[] a2) { + assert (a1 != null); + assert (a2 != null); + assert (a1.length == a2.length) : "Adding two arrays of different " + + "length"; + + for (int i = 0; i < a1.length; i++) { + a1[i] += a2[i]; + } + } } http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/23a35dbe/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/cluster/Cluster.java ---------------------------------------------------------------------- diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/cluster/Cluster.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/cluster/Cluster.java index a9380a8..82cdb3b 100644 --- a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/cluster/Cluster.java +++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/cluster/Cluster.java @@ -1,4 +1,3 @@ - package com.yahoo.labs.samoa.moa.cluster; /* @@ -31,142 +30,139 @@ import com.yahoo.labs.samoa.moa.AbstractMOAObject; public abstract class Cluster extends AbstractMOAObject { - private static final long serialVersionUID = 1L; - - private double id = -1; - private double gtLabel = -1; - - private Map<String, String> measure_values; - - - public Cluster(){ - this.measure_values = new HashMap<>(); + private static final long serialVersionUID = 1L; + + private double id = -1; + private double gtLabel = -1; + + private Map<String, String> measure_values; + + public Cluster() { + this.measure_values = new HashMap<>(); + } + + /** + * @return the center of the cluster + */ + public abstract double[] getCenter(); + + /** + * Returns the weight of this cluster, not neccessarily normalized. It could, + * for instance, simply return the number of points contined in this cluster. + * + * @return the weight + */ + public abstract double getWeight(); + + /** + * Returns the probability of the given point belonging to this cluster. + * + * @param instance + * @return a value between 0 and 1 + */ + public abstract double getInclusionProbability(Instance instance); + + // TODO: for non sphere cluster sample points, find out MIN MAX neighbours + // within cluster + // and return the relative distance + // public abstract double getRelativeHullDistance(Instance instance); + + @Override + public void getDescription(StringBuilder sb, int i) { + sb.append("Cluster Object"); + } + + public void setId(double id) { + this.id = id; + } + + public double getId() { + return id; + } + + public boolean isGroundTruth() { + return gtLabel != -1; + } + + public void setGroundTruth(double truth) { + gtLabel = truth; + } + + public double getGroundTruth() { + return gtLabel; + } + + /** + * Samples this cluster by returning a point from inside it. + * + * @param random + * a random number source + * @return an Instance that lies inside this cluster + */ + public abstract Instance sample(Random random); + + public void setMeasureValue(String measureKey, String value) { + measure_values.put(measureKey, value); + } + + public void setMeasureValue(String measureKey, double value) { + measure_values.put(measureKey, Double.toString(value)); + } + + public String getMeasureValue(String measureKey) { + if (measure_values.containsKey(measureKey)) + return measure_values.get(measureKey); + else + return ""; + } + + protected void getClusterSpecificInfo(List<String> infoTitle, List<String> infoValue) { + infoTitle.add("ClusterID"); + infoValue.add(Integer.toString((int) getId())); + + infoTitle.add("Type"); + infoValue.add(getClass().getSimpleName()); + + double c[] = getCenter(); + if (c != null) + for (int i = 0; i < c.length; i++) { + infoTitle.add("Dim" + i); + infoValue.add(Double.toString(c[i])); + } + + infoTitle.add("Weight"); + infoValue.add(Double.toString(getWeight())); + + } + + public String getInfo() { + List<String> infoTitle = new ArrayList<>(); + List<String> infoValue = new ArrayList<>(); + getClusterSpecificInfo(infoTitle, infoValue); + + StringBuilder sb = new StringBuilder(); + + // Cluster properties + sb.append("<html>"); + sb.append("<table>"); + int i = 0; + while (i < infoTitle.size() && i < infoValue.size()) { + sb.append("<tr><td>" + infoTitle.get(i) + "</td><td>" + infoValue.get(i) + "</td></tr>"); + i++; } - /** - * @return the center of the cluster - */ - public abstract double[] getCenter(); - - /** - * Returns the weight of this cluster, not neccessarily normalized. - * It could, for instance, simply return the number of points contined - * in this cluster. - * @return the weight - */ - public abstract double getWeight(); - - /** - * Returns the probability of the given point belonging to - * this cluster. - * - * @param instance - * @return a value between 0 and 1 - */ - public abstract double getInclusionProbability(Instance instance); - - - //TODO: for non sphere cluster sample points, find out MIN MAX neighbours within cluster - //and return the relative distance - //public abstract double getRelativeHullDistance(Instance instance); - - @Override - public void getDescription(StringBuilder sb, int i) { - sb.append("Cluster Object"); - } - - public void setId(double id) { - this.id = id; - } - - public double getId() { - return id; - } - - public boolean isGroundTruth(){ - return gtLabel != -1; - } - - public void setGroundTruth(double truth){ - gtLabel = truth; - } - - public double getGroundTruth(){ - return gtLabel; - } - - - /** - * Samples this cluster by returning a point from inside it. - * @param random a random number source - * @return an Instance that lies inside this cluster - */ - public abstract Instance sample(Random random); - - - public void setMeasureValue(String measureKey, String value){ - measure_values.put(measureKey, value); - } - - public void setMeasureValue(String measureKey, double value){ - measure_values.put(measureKey, Double.toString(value)); - } - - - public String getMeasureValue(String measureKey){ - if(measure_values.containsKey(measureKey)) - return measure_values.get(measureKey); - else - return ""; - } - - - protected void getClusterSpecificInfo(List<String> infoTitle, List<String> infoValue){ - infoTitle.add("ClusterID"); - infoValue.add(Integer.toString((int)getId())); - - infoTitle.add("Type"); - infoValue.add(getClass().getSimpleName()); - - double c[] = getCenter(); - if(c!=null) - for (int i = 0; i < c.length; i++) { - infoTitle.add("Dim"+i); - infoValue.add(Double.toString(c[i])); - } - - infoTitle.add("Weight"); - infoValue.add(Double.toString(getWeight())); - - } - - public String getInfo() { - List<String> infoTitle = new ArrayList<>(); - List<String> infoValue = new ArrayList<>(); - getClusterSpecificInfo(infoTitle, infoValue); - - StringBuilder sb = new StringBuilder(); - - //Cluster properties - sb.append("<html>"); - sb.append("<table>"); - int i = 0; - while(i < infoTitle.size() && i < infoValue.size()){ - sb.append("<tr><td>"+infoTitle.get(i)+"</td><td>"+infoValue.get(i)+"</td></tr>"); - i++; - } - sb.append("</table>"); - - //Evaluation info - sb.append("<br>"); - sb.append("<b>Evaluation</b><br>"); - sb.append("<table>"); - for (Object o : measure_values.entrySet()) { - Map.Entry e = (Map.Entry) o; - sb.append("<tr><td>" + e.getKey() + "</td><td>" + e.getValue() + "</td></tr>"); - } - sb.append("</table>"); - sb.append("</html>"); - return sb.toString(); + sb.append("</table>"); + + // Evaluation info + sb.append("<br>"); + sb.append("<b>Evaluation</b><br>"); + sb.append("<table>"); + for (Object o : measure_values.entrySet()) { + Map.Entry e = (Map.Entry) o; + sb.append("<tr><td>" + e.getKey() + "</td><td>" + e.getValue() + "</td></tr>"); } + sb.append("</table>"); + sb.append("</html>"); + return sb.toString(); + } }
