http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/9b178f63/samoa-api/src/main/java/org/apache/samoa/evaluation/BasicRegressionPerformanceEvaluator.java ---------------------------------------------------------------------- diff --git a/samoa-api/src/main/java/org/apache/samoa/evaluation/BasicRegressionPerformanceEvaluator.java b/samoa-api/src/main/java/org/apache/samoa/evaluation/BasicRegressionPerformanceEvaluator.java new file mode 100644 index 0000000..ec48156 --- /dev/null +++ b/samoa-api/src/main/java/org/apache/samoa/evaluation/BasicRegressionPerformanceEvaluator.java @@ -0,0 +1,134 @@ +package org.apache.samoa.evaluation; + +/* + * #%L + * SAMOA + * %% + * Copyright (C) 2014 - 2015 Apache Software Foundation + * %% + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * #L% + */ + +import org.apache.samoa.instances.Instance; +import org.apache.samoa.moa.AbstractMOAObject; +import org.apache.samoa.moa.core.Measurement; + +/** + * Regression evaluator that performs basic incremental evaluation. 
+ * + * @author Albert Bifet (abifet at cs dot waikato dot ac dot nz) + * @version $Revision: 7 $ + */ +public class BasicRegressionPerformanceEvaluator extends AbstractMOAObject + implements RegressionPerformanceEvaluator { + + private static final long serialVersionUID = 1L; + + protected double weightObserved; + + protected double squareError; + + protected double averageError; + + protected double sumTarget; + + protected double squareTargetError; + + protected double averageTargetError; + + @Override + public void reset() { + this.weightObserved = 0.0; + this.squareError = 0.0; + this.averageError = 0.0; + this.sumTarget = 0.0; + this.averageTargetError = 0.0; + this.squareTargetError = 0.0; + + } + + @Override + public void addResult(Instance inst, double[] prediction) { + double weight = inst.weight(); + double classValue = inst.classValue(); + if (weight > 0.0) { + if (prediction.length > 0) { + double meanTarget = this.weightObserved != 0 ? + this.sumTarget / this.weightObserved : 0.0; + this.squareError += (classValue - prediction[0]) * (classValue - prediction[0]); + this.averageError += Math.abs(classValue - prediction[0]); + this.squareTargetError += (classValue - meanTarget) * (classValue - meanTarget); + this.averageTargetError += Math.abs(classValue - meanTarget); + this.sumTarget += classValue; + this.weightObserved += weight; + } + } + } + + @Override + public Measurement[] getPerformanceMeasurements() { + return new Measurement[] { + new Measurement("classified instances", + getTotalWeightObserved()), + new Measurement("mean absolute error", + getMeanError()), + new Measurement("root mean squared error", + getSquareError()), + new Measurement("relative mean absolute error", + getRelativeMeanError()), + new Measurement("relative root mean squared error", + getRelativeSquareError()) + }; + } + + public double getTotalWeightObserved() { + return this.weightObserved; + } + + public double getMeanError() { + return this.weightObserved > 0.0 ? 
this.averageError + / this.weightObserved : 0.0; + } + + public double getSquareError() { + return Math.sqrt(this.weightObserved > 0.0 ? this.squareError + / this.weightObserved : 0.0); + } + + public double getTargetMeanError() { + return this.weightObserved > 0.0 ? this.averageTargetError + / this.weightObserved : 0.0; + } + + public double getTargetSquareError() { + return Math.sqrt(this.weightObserved > 0.0 ? this.squareTargetError + / this.weightObserved : 0.0); + } + + @Override + public void getDescription(StringBuilder sb, int indent) { + Measurement.getMeasurementsDescription(getPerformanceMeasurements(), + sb, indent); + } + + private double getRelativeMeanError() { + return this.averageTargetError > 0 ? + this.averageError / this.averageTargetError : 0.0; + } + + private double getRelativeSquareError() { + return Math.sqrt(this.squareTargetError > 0 ? + this.squareError / this.squareTargetError : 0.0); + } +}
http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/9b178f63/samoa-api/src/main/java/org/apache/samoa/evaluation/ClassificationPerformanceEvaluator.java ---------------------------------------------------------------------- diff --git a/samoa-api/src/main/java/org/apache/samoa/evaluation/ClassificationPerformanceEvaluator.java b/samoa-api/src/main/java/org/apache/samoa/evaluation/ClassificationPerformanceEvaluator.java new file mode 100644 index 0000000..55fc553 --- /dev/null +++ b/samoa-api/src/main/java/org/apache/samoa/evaluation/ClassificationPerformanceEvaluator.java @@ -0,0 +1,24 @@ +package org.apache.samoa.evaluation; + +/* + * #%L + * SAMOA + * %% + * Copyright (C) 2014 - 2015 Apache Software Foundation + * %% + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * #L% + */ + +public interface ClassificationPerformanceEvaluator extends PerformanceEvaluator { +} http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/9b178f63/samoa-api/src/main/java/org/apache/samoa/evaluation/ClusteringEvaluationContentEvent.java ---------------------------------------------------------------------- diff --git a/samoa-api/src/main/java/org/apache/samoa/evaluation/ClusteringEvaluationContentEvent.java b/samoa-api/src/main/java/org/apache/samoa/evaluation/ClusteringEvaluationContentEvent.java new file mode 100644 index 0000000..67bdeec --- /dev/null +++ b/samoa-api/src/main/java/org/apache/samoa/evaluation/ClusteringEvaluationContentEvent.java @@ -0,0 +1,85 @@ +package org.apache.samoa.evaluation; + +/* + * #%L + * SAMOA + * %% + * Copyright (C) 2014 - 2015 Apache Software Foundation + * %% + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * #L% + */ +import org.apache.samoa.core.*; +import org.apache.samoa.instances.Instance; +import org.apache.samoa.moa.cluster.Clustering; +import org.apache.samoa.moa.core.DataPoint; + +/** + * License + */ +/** + * The Class Clustering ResultEvent. 
+ */ +final public class ClusteringEvaluationContentEvent implements ContentEvent { + + private static final long serialVersionUID = -7746983521296618922L; + private Clustering gtClustering; + private DataPoint dataPoint; + private final boolean isLast; + private String key = "0"; + + public ClusteringEvaluationContentEvent() { + this.isLast = false; + } + + public ClusteringEvaluationContentEvent(boolean isLast) { + this.isLast = isLast; + } + + /** + * Instantiates a new gtClustering result event. + * + * @param clustering + * the gtClustering result + * @param instance + * data point + * @param isLast + * is the last result + */ + public ClusteringEvaluationContentEvent(Clustering clustering, DataPoint instance, boolean isLast) { + this.gtClustering = clustering; + this.isLast = isLast; + this.dataPoint = instance; + } + + public String getKey() { + return key; + } + + public void setKey(String key) { + this.key = key; + } + + public boolean isLastEvent() { + return this.isLast; + } + + Clustering getGTClustering() { + return this.gtClustering; + } + + DataPoint getDataPoint() { + return this.dataPoint; + } + +} http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/9b178f63/samoa-api/src/main/java/org/apache/samoa/evaluation/ClusteringEvaluatorProcessor.java ---------------------------------------------------------------------- diff --git a/samoa-api/src/main/java/org/apache/samoa/evaluation/ClusteringEvaluatorProcessor.java b/samoa-api/src/main/java/org/apache/samoa/evaluation/ClusteringEvaluatorProcessor.java new file mode 100644 index 0000000..6649d81 --- /dev/null +++ b/samoa-api/src/main/java/org/apache/samoa/evaluation/ClusteringEvaluatorProcessor.java @@ -0,0 +1,320 @@ +package org.apache.samoa.evaluation; + +/* + * #%L + * SAMOA + * %% + * Copyright (C) 2014 - 2015 Apache Software Foundation + * %% + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * #L% + */ + +import java.io.File; +import java.io.FileNotFoundException; +import java.io.FileOutputStream; +import java.io.PrintStream; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.TimeUnit; + +import org.apache.samoa.core.ContentEvent; +import org.apache.samoa.core.Processor; +import org.apache.samoa.evaluation.measures.SSQ; +import org.apache.samoa.evaluation.measures.StatisticalCollection; +import org.apache.samoa.moa.cluster.Clustering; +import org.apache.samoa.moa.clusterers.KMeans; +import org.apache.samoa.moa.core.DataPoint; +import org.apache.samoa.moa.core.Measurement; +import org.apache.samoa.moa.evaluation.LearningCurve; +import org.apache.samoa.moa.evaluation.LearningEvaluation; +import org.apache.samoa.moa.evaluation.MeasureCollection; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class ClusteringEvaluatorProcessor implements Processor { + + /** + * + */ + private static final long serialVersionUID = -2778051819116753612L; + + private static final Logger logger = LoggerFactory.getLogger(EvaluatorProcessor.class); + + private static final String ORDERING_MEASUREMENT_NAME = "evaluation instances"; + + private final int samplingFrequency; + private final int decayHorizon; + private final File dumpFile; + private transient PrintStream immediateResultStream = null; + private transient boolean firstDump = true; + + private long totalCount = 0; + private long experimentStart = 0; + + private LearningCurve learningCurve; + + private MeasureCollection[] measures; + + private int id; 
+ + protected Clustering gtClustering; + + protected ArrayList<DataPoint> points; + + private ClusteringEvaluatorProcessor(Builder builder) { + this.samplingFrequency = builder.samplingFrequency; + this.dumpFile = builder.dumpFile; + this.points = new ArrayList<>(); + this.decayHorizon = builder.decayHorizon; + } + + @Override + public boolean process(ContentEvent event) { + boolean ret = false; + if (event instanceof ClusteringResultContentEvent) { + ret = process((ClusteringResultContentEvent) event); + } + if (event instanceof ClusteringEvaluationContentEvent) { + ret = process((ClusteringEvaluationContentEvent) event); + } + return ret; + } + + private boolean process(ClusteringResultContentEvent result) { + // evaluate + Clustering clustering = KMeans.gaussianMeans(gtClustering, result.getClustering()); + for (MeasureCollection measure : measures) { + try { + measure.evaluateClusteringPerformance(clustering, gtClustering, points); + } catch (Exception ex) { + ex.printStackTrace(); + } + } + + this.addMeasurement(); + + if (result.isLastEvent()) { + this.concludeMeasurement(); + return true; + } + + totalCount += 1; + + if (totalCount == 1) { + experimentStart = System.nanoTime(); + } + + return false; + } + + private boolean process(ClusteringEvaluationContentEvent result) { + boolean ret = false; + if (result.getGTClustering() != null) { + gtClustering = result.getGTClustering(); + ret = true; + } + if (result.getDataPoint() != null) { + points.add(result.getDataPoint()); + if (points.size() > this.decayHorizon) { + points.remove(0); + } + ret = true; + } + return ret; + } + + @Override + public void onCreate(int id) { + this.id = id; + this.learningCurve = new LearningCurve(ORDERING_MEASUREMENT_NAME); + // create the measure collection + measures = getMeasures(getMeasureSelection()); + + if (this.dumpFile != null) { + try { + if (dumpFile.exists()) { + this.immediateResultStream = new PrintStream(new FileOutputStream(dumpFile, true), true); + } else { + 
this.immediateResultStream = new PrintStream(new FileOutputStream(dumpFile), true); + } + + } catch (FileNotFoundException e) { + this.immediateResultStream = null; + logger.error("File not found exception for {}:{}", this.dumpFile.getAbsolutePath(), e.toString()); + + } catch (Exception e) { + this.immediateResultStream = null; + logger.error("Exception when creating {}:{}", this.dumpFile.getAbsolutePath(), e.toString()); + } + } + + this.firstDump = true; + } + + private static ArrayList<Class> getMeasureSelection() { + ArrayList<Class> mclasses = new ArrayList<>(); + // mclasses.add(EntropyCollection.class); + // mclasses.add(F1.class); + // mclasses.add(General.class); + // *mclasses.add(CMM.class); + mclasses.add(SSQ.class); + // *mclasses.add(SilhouetteCoefficient.class); + mclasses.add(StatisticalCollection.class); + // mclasses.add(Separation.class); + + return mclasses; + } + + private static MeasureCollection[] getMeasures(ArrayList<Class> measure_classes) { + MeasureCollection[] measures = new MeasureCollection[measure_classes.size()]; + for (int i = 0; i < measure_classes.size(); i++) { + try { + MeasureCollection m = (MeasureCollection) measure_classes.get(i).newInstance(); + measures[i] = m; + + } catch (Exception ex) { + java.util.logging.Logger.getLogger("Couldn't create Instance for " + measure_classes.get(i).getName()); + ex.printStackTrace(); + } + } + return measures; + } + + @Override + public Processor newProcessor(Processor p) { + ClusteringEvaluatorProcessor originalProcessor = (ClusteringEvaluatorProcessor) p; + ClusteringEvaluatorProcessor newProcessor = new ClusteringEvaluatorProcessor.Builder(originalProcessor).build(); + + if (originalProcessor.learningCurve != null) { + newProcessor.learningCurve = originalProcessor.learningCurve; + } + + return newProcessor; + } + + @Override + public String toString() { + StringBuilder report = new StringBuilder(); + + report.append(EvaluatorProcessor.class.getCanonicalName()); + report.append("id = 
").append(this.id); + report.append('\n'); + + if (learningCurve.numEntries() > 0) { + report.append(learningCurve.toString()); + report.append('\n'); + } + return report.toString(); + } + + private void addMeasurement() { + // printMeasures(); + List<Measurement> measurements = new ArrayList<>(); + measurements.add(new Measurement(ORDERING_MEASUREMENT_NAME, totalCount * this.samplingFrequency)); + + addClusteringPerformanceMeasurements(measurements); + Measurement[] finalMeasurements = measurements.toArray(new Measurement[measurements.size()]); + + LearningEvaluation learningEvaluation = new LearningEvaluation(finalMeasurements); + learningCurve.insertEntry(learningEvaluation); + logger.debug("evaluator id = {}", this.id); + // logger.info(learningEvaluation.toString()); + + if (immediateResultStream != null) { + if (firstDump) { + immediateResultStream.println(learningCurve.headerToString()); + firstDump = false; + } + + immediateResultStream.println(learningCurve.entryToString(learningCurve.numEntries() - 1)); + immediateResultStream.flush(); + } + } + + private void addClusteringPerformanceMeasurements(List<Measurement> measurements) { + for (MeasureCollection measure : measures) { + for (int j = 0; j < measure.getNumMeasures(); j++) { + Measurement measurement = new Measurement(measure.getName(j), measure.getLastValue(j)); + measurements.add(measurement); + } + } + } + + private void concludeMeasurement() { + logger.info("last event is received!"); + logger.info("total count: {}", this.totalCount); + + String learningCurveSummary = this.toString(); + logger.info(learningCurveSummary); + + long experimentEnd = System.nanoTime(); + long totalExperimentTime = TimeUnit.SECONDS.convert(experimentEnd - experimentStart, TimeUnit.NANOSECONDS); + logger.info("total evaluation time: {} seconds for {} instances", totalExperimentTime, totalCount); + // logger.info("average throughput rate: {} instances/seconds", + // (totalCount/totalExperimentTime)); + } + + private void 
printMeasures() { + StringBuilder sb = new StringBuilder(); + for (MeasureCollection measure : measures) { + + sb.append("Mean ").append(measure.getClass().getSimpleName()).append(":").append(measure.getNumMeasures()) + .append("\n"); + for (int j = 0; j < measure.getNumMeasures(); j++) { + sb.append("[").append(measure.getName(j)).append("=").append(measure.getLastValue(j)).append("] \n"); + + } + sb.append("\n"); + } + + logger.debug("\n MEASURES: \n\n {}", sb.toString()); + System.out.println(sb.toString()); + } + + public static class Builder { + + private int samplingFrequency = 1000; + private File dumpFile = null; + private int decayHorizon = 1000; + + public Builder(int samplingFrequency) { + this.samplingFrequency = samplingFrequency; + } + + public Builder(ClusteringEvaluatorProcessor oldProcessor) { + this.samplingFrequency = oldProcessor.samplingFrequency; + this.dumpFile = oldProcessor.dumpFile; + this.decayHorizon = oldProcessor.decayHorizon; + } + + public Builder samplingFrequency(int samplingFrequency) { + this.samplingFrequency = samplingFrequency; + return this; + } + + public Builder decayHorizon(int decayHorizon) { + this.decayHorizon = decayHorizon; + return this; + } + + public Builder dumpFile(File file) { + this.dumpFile = file; + return this; + } + + public ClusteringEvaluatorProcessor build() { + return new ClusteringEvaluatorProcessor(this); + } + } +} http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/9b178f63/samoa-api/src/main/java/org/apache/samoa/evaluation/ClusteringResultContentEvent.java ---------------------------------------------------------------------- diff --git a/samoa-api/src/main/java/org/apache/samoa/evaluation/ClusteringResultContentEvent.java b/samoa-api/src/main/java/org/apache/samoa/evaluation/ClusteringResultContentEvent.java new file mode 100644 index 0000000..5d6494c --- /dev/null +++ b/samoa-api/src/main/java/org/apache/samoa/evaluation/ClusteringResultContentEvent.java @@ -0,0 +1,75 @@ +package 
org.apache.samoa.evaluation; + +/* + * #%L + * SAMOA + * %% + * Copyright (C) 2014 - 2015 Apache Software Foundation + * %% + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * #L% + */ + +import org.apache.samoa.core.ContentEvent; +import org.apache.samoa.moa.cluster.Clustering; + +/** + * License + */ +/** + * The Class Clustering ResultEvent. + */ +final public class ClusteringResultContentEvent implements ContentEvent { + + private static final long serialVersionUID = -7746983521296618922L; + private Clustering clustering; + private final boolean isLast; + private String key = "0"; + + public ClusteringResultContentEvent() { + this.isLast = false; + } + + public ClusteringResultContentEvent(boolean isLast) { + this.isLast = isLast; + } + + /** + * Instantiates a new clustering result event. 
+ * + * @param clustering + * the clustering result + * @param isLast + * is the last result + */ + public ClusteringResultContentEvent(Clustering clustering, boolean isLast) { + this.clustering = clustering; + this.isLast = isLast; + } + + public String getKey() { + return key; + } + + public void setKey(String key) { + this.key = key; + } + + public boolean isLastEvent() { + return this.isLast; + } + + public Clustering getClustering() { + return this.clustering; + } +} http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/9b178f63/samoa-api/src/main/java/org/apache/samoa/evaluation/EvaluatorProcessor.java ---------------------------------------------------------------------- diff --git a/samoa-api/src/main/java/org/apache/samoa/evaluation/EvaluatorProcessor.java b/samoa-api/src/main/java/org/apache/samoa/evaluation/EvaluatorProcessor.java new file mode 100644 index 0000000..6ec50dc --- /dev/null +++ b/samoa-api/src/main/java/org/apache/samoa/evaluation/EvaluatorProcessor.java @@ -0,0 +1,231 @@ +package org.apache.samoa.evaluation; + +/* + * #%L + * SAMOA + * %% + * Copyright (C) 2014 - 2015 Apache Software Foundation + * %% + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * #L% + */ + +import java.io.File; +import java.io.FileNotFoundException; +import java.io.FileOutputStream; +import java.io.PrintStream; +import java.util.Collections; +import java.util.List; +import java.util.Vector; +import java.util.concurrent.TimeUnit; + +import org.apache.samoa.core.ContentEvent; +import org.apache.samoa.core.Processor; +import org.apache.samoa.learners.ResultContentEvent; +import org.apache.samoa.moa.core.Measurement; +import org.apache.samoa.moa.evaluation.LearningCurve; +import org.apache.samoa.moa.evaluation.LearningEvaluation; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class EvaluatorProcessor implements Processor { + + /** + * + */ + private static final long serialVersionUID = -2778051819116753612L; + + private static final Logger logger = + LoggerFactory.getLogger(EvaluatorProcessor.class); + + private static final String ORDERING_MEASUREMENT_NAME = "evaluation instances"; + + private final PerformanceEvaluator evaluator; + private final int samplingFrequency; + private final File dumpFile; + private transient PrintStream immediateResultStream = null; + private transient boolean firstDump = true; + + private long totalCount = 0; + private long experimentStart = 0; + + private long sampleStart = 0; + + private LearningCurve learningCurve; + private int id; + + private EvaluatorProcessor(Builder builder) { + this.evaluator = builder.evaluator; + this.samplingFrequency = builder.samplingFrequency; + this.dumpFile = builder.dumpFile; + } + + @Override + public boolean process(ContentEvent event) { + + ResultContentEvent result = (ResultContentEvent) event; + + if ((totalCount > 0) && (totalCount % samplingFrequency) == 0) { + long sampleEnd = System.nanoTime(); + long sampleDuration = TimeUnit.SECONDS.convert(sampleEnd - sampleStart, TimeUnit.NANOSECONDS); + sampleStart = sampleEnd; + + logger.info("{} seconds for {} instances", sampleDuration, samplingFrequency); + this.addMeasurement(); + } + + if 
(result.isLastEvent()) { + this.concludeMeasurement(); + return true; + } + + evaluator.addResult(result.getInstance(), result.getClassVotes()); + totalCount += 1; + + if (totalCount == 1) { + sampleStart = System.nanoTime(); + experimentStart = sampleStart; + } + + return false; + } + + @Override + public void onCreate(int id) { + this.id = id; + this.learningCurve = new LearningCurve(ORDERING_MEASUREMENT_NAME); + + if (this.dumpFile != null) { + try { + if (dumpFile.exists()) { + this.immediateResultStream = new PrintStream( + new FileOutputStream(dumpFile, true), true); + } else { + this.immediateResultStream = new PrintStream( + new FileOutputStream(dumpFile), true); + } + + } catch (FileNotFoundException e) { + this.immediateResultStream = null; + logger.error("File not found exception for {}:{}", this.dumpFile.getAbsolutePath(), e.toString()); + + } catch (Exception e) { + this.immediateResultStream = null; + logger.error("Exception when creating {}:{}", this.dumpFile.getAbsolutePath(), e.toString()); + } + } + + this.firstDump = true; + } + + @Override + public Processor newProcessor(Processor p) { + EvaluatorProcessor originalProcessor = (EvaluatorProcessor) p; + EvaluatorProcessor newProcessor = new EvaluatorProcessor.Builder(originalProcessor).build(); + + if (originalProcessor.learningCurve != null) { + newProcessor.learningCurve = originalProcessor.learningCurve; + } + + return newProcessor; + } + + @Override + public String toString() { + StringBuilder report = new StringBuilder(); + + report.append(EvaluatorProcessor.class.getCanonicalName()); + report.append("id = ").append(this.id); + report.append('\n'); + + if (learningCurve.numEntries() > 0) { + report.append(learningCurve.toString()); + report.append('\n'); + } + return report.toString(); + } + + private void addMeasurement() { + List<Measurement> measurements = new Vector<>(); + measurements.add(new Measurement(ORDERING_MEASUREMENT_NAME, totalCount)); + + Collections.addAll(measurements, 
evaluator.getPerformanceMeasurements()); + + Measurement[] finalMeasurements = measurements.toArray(new Measurement[measurements.size()]); + + LearningEvaluation learningEvaluation = new LearningEvaluation(finalMeasurements); + learningCurve.insertEntry(learningEvaluation); + logger.debug("evaluator id = {}", this.id); + logger.info(learningEvaluation.toString()); + + if (immediateResultStream != null) { + if (firstDump) { + immediateResultStream.println(learningCurve.headerToString()); + firstDump = false; + } + + immediateResultStream.println(learningCurve.entryToString(learningCurve.numEntries() - 1)); + immediateResultStream.flush(); + } + } + + private void concludeMeasurement() { + logger.info("last event is received!"); + logger.info("total count: {}", this.totalCount); + + String learningCurveSummary = this.toString(); + logger.info(learningCurveSummary); + + long experimentEnd = System.nanoTime(); + long totalExperimentTime = TimeUnit.SECONDS.convert(experimentEnd - experimentStart, TimeUnit.NANOSECONDS); + logger.info("total evaluation time: {} seconds for {} instances", totalExperimentTime, totalCount); + + if (immediateResultStream != null) { + immediateResultStream.println("# COMPLETED"); + immediateResultStream.flush(); + } + // logger.info("average throughput rate: {} instances/seconds", + // (totalCount/totalExperimentTime)); + } + + public static class Builder { + + private final PerformanceEvaluator evaluator; + private int samplingFrequency = 100000; + private File dumpFile = null; + + public Builder(PerformanceEvaluator evaluator) { + this.evaluator = evaluator; + } + + public Builder(EvaluatorProcessor oldProcessor) { + this.evaluator = oldProcessor.evaluator; + this.samplingFrequency = oldProcessor.samplingFrequency; + this.dumpFile = oldProcessor.dumpFile; + } + + public Builder samplingFrequency(int samplingFrequency) { + this.samplingFrequency = samplingFrequency; + return this; + } + + public Builder dumpFile(File file) { + this.dumpFile = 
file; + return this; + } + + public EvaluatorProcessor build() { + return new EvaluatorProcessor(this); + } + } +} http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/9b178f63/samoa-api/src/main/java/org/apache/samoa/evaluation/PerformanceEvaluator.java ---------------------------------------------------------------------- diff --git a/samoa-api/src/main/java/org/apache/samoa/evaluation/PerformanceEvaluator.java b/samoa-api/src/main/java/org/apache/samoa/evaluation/PerformanceEvaluator.java new file mode 100644 index 0000000..0bd2450 --- /dev/null +++ b/samoa-api/src/main/java/org/apache/samoa/evaluation/PerformanceEvaluator.java @@ -0,0 +1,58 @@ +package org.apache.samoa.evaluation; + +/* + * #%L + * SAMOA + * %% + * Copyright (C) 2014 - 2015 Apache Software Foundation + * %% + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * #L% + */ + +import org.apache.samoa.instances.Instance; +import org.apache.samoa.moa.MOAObject; +import org.apache.samoa.moa.core.Measurement; + +/** + * Interface implemented by learner evaluators to monitor the results of the learning process. + * + * @author Richard Kirkby ([email protected]) + * @version $Revision: 7 $ + */ +public interface PerformanceEvaluator extends MOAObject { + + /** + * Resets this evaluator. It must be similar to starting a new evaluator from scratch. + * + */ + public void reset(); + + /** + * Adds a learning result to this evaluator. 
+ * + * @param inst + * the instance to be classified + * @param classVotes + * an array containing the estimated membership probabilities of the test instance in each class + * @return an array of measurements monitored in this evaluator + */ + public void addResult(Instance inst, double[] classVotes); + + /** + * Gets the current measurements monitored by this evaluator. + * + * @return an array of measurements monitored by this evaluator + */ + public Measurement[] getPerformanceMeasurements(); +} http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/9b178f63/samoa-api/src/main/java/org/apache/samoa/evaluation/RegressionPerformanceEvaluator.java ---------------------------------------------------------------------- diff --git a/samoa-api/src/main/java/org/apache/samoa/evaluation/RegressionPerformanceEvaluator.java b/samoa-api/src/main/java/org/apache/samoa/evaluation/RegressionPerformanceEvaluator.java new file mode 100644 index 0000000..8042fee --- /dev/null +++ b/samoa-api/src/main/java/org/apache/samoa/evaluation/RegressionPerformanceEvaluator.java @@ -0,0 +1,25 @@ +package org.apache.samoa.evaluation; + +/* + * #%L + * SAMOA + * %% + * Copyright (C) 2014 - 2015 Apache Software Foundation + * %% + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * #L% + */ + +public interface RegressionPerformanceEvaluator extends PerformanceEvaluator { + +} http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/9b178f63/samoa-api/src/main/java/org/apache/samoa/evaluation/WindowClassificationPerformanceEvaluator.java ---------------------------------------------------------------------- diff --git a/samoa-api/src/main/java/org/apache/samoa/evaluation/WindowClassificationPerformanceEvaluator.java b/samoa-api/src/main/java/org/apache/samoa/evaluation/WindowClassificationPerformanceEvaluator.java new file mode 100644 index 0000000..c428a7f --- /dev/null +++ b/samoa-api/src/main/java/org/apache/samoa/evaluation/WindowClassificationPerformanceEvaluator.java @@ -0,0 +1,219 @@ +package org.apache.samoa.evaluation; + +/* + * #%L + * SAMOA + * %% + * Copyright (C) 2014 - 2015 Apache Software Foundation + * %% + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * #L% + */ + +import org.apache.samoa.instances.Instance; +import org.apache.samoa.instances.Utils; +import org.apache.samoa.moa.AbstractMOAObject; +import org.apache.samoa.moa.core.Measurement; + +import com.github.javacliparser.IntOption; + +/** + * Classification evaluator that updates evaluation results using a sliding window. 
+ * + * @author Albert Bifet (abifet at cs dot waikato dot ac dot nz) + * @version $Revision: 7 $ + */ +public class WindowClassificationPerformanceEvaluator extends AbstractMOAObject implements + ClassificationPerformanceEvaluator { + + private static final long serialVersionUID = 1L; + + public IntOption widthOption = new IntOption("width", + 'w', "Size of Window", 1000); + + protected double TotalweightObserved = 0; + + protected Estimator weightObserved; + + protected Estimator weightCorrect; + + protected Estimator weightCorrectNoChangeClassifier; + + protected double lastSeenClass; + + protected Estimator[] columnKappa; + + protected Estimator[] rowKappa; + + protected Estimator[] classAccuracy; + + protected int numClasses; + + public class Estimator { + + protected double[] window; + + protected int posWindow; + + protected int lenWindow; + + protected int SizeWindow; + + protected double sum; + + public Estimator(int sizeWindow) { + window = new double[sizeWindow]; + SizeWindow = sizeWindow; + posWindow = 0; + lenWindow = 0; + } + + public void add(double value) { + sum -= window[posWindow]; + sum += value; + window[posWindow] = value; + posWindow++; + if (posWindow == SizeWindow) { + posWindow = 0; + } + if (lenWindow < SizeWindow) { + lenWindow++; + } + } + + public double total() { + return sum; + } + + public double length() { + return lenWindow; + } + + } + + /* + * public void setWindowWidth(int w) { this.width = w; reset(); } + */ + @Override + public void reset() { + reset(this.numClasses); + } + + public void reset(int numClasses) { + this.numClasses = numClasses; + this.rowKappa = new Estimator[numClasses]; + this.columnKappa = new Estimator[numClasses]; + this.classAccuracy = new Estimator[numClasses]; + for (int i = 0; i < this.numClasses; i++) { + this.rowKappa[i] = new Estimator(this.widthOption.getValue()); + this.columnKappa[i] = new Estimator(this.widthOption.getValue()); + this.classAccuracy[i] = new 
Estimator(this.widthOption.getValue()); + } + this.weightCorrect = new Estimator(this.widthOption.getValue()); + this.weightCorrectNoChangeClassifier = new Estimator(this.widthOption.getValue()); + this.weightObserved = new Estimator(this.widthOption.getValue()); + this.TotalweightObserved = 0; + this.lastSeenClass = 0; + } + + @Override + public void addResult(Instance inst, double[] classVotes) { + double weight = inst.weight(); + int trueClass = (int) inst.classValue(); + if (weight > 0.0) { + if (TotalweightObserved == 0) { + reset(inst.numClasses()); + } + this.TotalweightObserved += weight; + this.weightObserved.add(weight); + int predictedClass = Utils.maxIndex(classVotes); + if (predictedClass == trueClass) { + this.weightCorrect.add(weight); + } else { + this.weightCorrect.add(0); + } + // Add Kappa statistic information + for (int i = 0; i < this.numClasses; i++) { + this.rowKappa[i].add(i == predictedClass ? weight : 0); + this.columnKappa[i].add(i == trueClass ? weight : 0); + } + if (this.lastSeenClass == trueClass) { + this.weightCorrectNoChangeClassifier.add(weight); + } else { + this.weightCorrectNoChangeClassifier.add(0); + } + this.classAccuracy[trueClass].add(predictedClass == trueClass ? weight : 0.0); + this.lastSeenClass = trueClass; + } + } + + @Override + public Measurement[] getPerformanceMeasurements() { + return new Measurement[] { + new Measurement("classified instances", + this.TotalweightObserved), + new Measurement("classifications correct (percent)", + getFractionCorrectlyClassified() * 100.0), + new Measurement("Kappa Statistic (percent)", + getKappaStatistic() * 100.0), + new Measurement("Kappa Temporal Statistic (percent)", + getKappaTemporalStatistic() * 100.0) + }; + + } + + public double getTotalWeightObserved() { + return this.weightObserved.total(); + } + + public double getFractionCorrectlyClassified() { + return this.weightObserved.total() > 0.0 ? 
this.weightCorrect.total() + / this.weightObserved.total() : 0.0; + } + + public double getKappaStatistic() { + if (this.weightObserved.total() > 0.0) { + double p0 = this.weightCorrect.total() / this.weightObserved.total(); + double pc = 0; + for (int i = 0; i < this.numClasses; i++) { + pc += (this.rowKappa[i].total() / this.weightObserved.total()) + * (this.columnKappa[i].total() / this.weightObserved.total()); + } + return (p0 - pc) / (1 - pc); + } else { + return 0; + } + } + + public double getKappaTemporalStatistic() { + if (this.weightObserved.total() > 0.0) { + double p0 = this.weightCorrect.total() / this.weightObserved.total(); + double pc = this.weightCorrectNoChangeClassifier.total() / this.weightObserved.total(); + + return (p0 - pc) / (1 - pc); + } else { + return 0; + } + } + + public double getFractionIncorrectlyClassified() { + return 1.0 - getFractionCorrectlyClassified(); + } + + @Override + public void getDescription(StringBuilder sb, int indent) { + Measurement.getMeasurementsDescription(getPerformanceMeasurements(), + sb, indent); + } + +} http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/9b178f63/samoa-api/src/main/java/org/apache/samoa/evaluation/measures/CMM.java ---------------------------------------------------------------------- diff --git a/samoa-api/src/main/java/org/apache/samoa/evaluation/measures/CMM.java b/samoa-api/src/main/java/org/apache/samoa/evaluation/measures/CMM.java new file mode 100644 index 0000000..9ea95f9 --- /dev/null +++ b/samoa-api/src/main/java/org/apache/samoa/evaluation/measures/CMM.java @@ -0,0 +1,515 @@ +package org.apache.samoa.evaluation.measures; + +/* + * #%L + * SAMOA + * %% + * Copyright (C) 2014 - 2015 Apache Software Foundation + * %% + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * #L% + */ + +import java.util.ArrayList; + +import org.apache.samoa.evaluation.measures.CMM_GTAnalysis.CMMPoint; +import org.apache.samoa.moa.cluster.Cluster; +import org.apache.samoa.moa.cluster.Clustering; +import org.apache.samoa.moa.cluster.SphereCluster; +import org.apache.samoa.moa.core.DataPoint; +import org.apache.samoa.moa.evaluation.MeasureCollection; + +/** + * [CMM.java] + * + * CMM: Main class + * + * Reference: Kremer et al., "An Effective Evaluation Measure for Clustering on Evolving Data Streams", KDD, 2011 + * + * @author Timm jansen Data Management and Data Exploration Group, RWTH Aachen University + */ + +public class CMM extends MeasureCollection { + + private static final long serialVersionUID = 1L; + + /** + * found clustering + */ + private Clustering clustering; + + /** + * the ground truth analysis + */ + private CMM_GTAnalysis gtAnalysis; + + /** + * number of points within the horizon + */ + private int numPoints; + + /** + * number of clusters in the found clustering + */ + private int numFClusters; + + /** + * number of cluster in the adjusted groundtruth clustering that was calculated through the groundtruth analysis + */ + private int numGT0Classes; + + /** + * match found clusters to GT clusters + */ + private int matchMap[]; + + /** + * pointInclusionProbFC[p][C] contains the probability of point p being included in cluster C + */ + private double[][] pointInclusionProbFC; + + /** + * threshold that defines when a point is being considered belonging to a cluster + */ + private double pointInclusionProbThreshold = 
0.5; + + /** + * parameterize the error weight of missed points (default 1) + */ + private double lamdaMissed = 1; + + /** + * enable/disable debug mode + */ + public boolean debug = false; + + /** + * enable/disable class merge (main feature of ground truth analysis) + */ + public boolean enableClassMerge = true; + + /** + * enable/disable model error when enabled errors that are caused by the underling cluster model will not be counted + */ + public boolean enableModelError = true; + + @Override + protected String[] getNames() { + String[] names = { "CMM", "CMM Basic", "CMM Missed", "CMM Misplaced", "CMM Noise", + "CA Seperability", "CA Noise", "CA Modell" }; + return names; + } + + @Override + protected boolean[] getDefaultEnabled() { + boolean[] defaults = { false, false, false, false, false, false, false, false }; + return defaults; + } + + @Override + public void evaluateClustering(Clustering clustering, Clustering trueClustering, ArrayList<DataPoint> points) + throws Exception { + this.clustering = clustering; + + numPoints = points.size(); + numFClusters = clustering.size(); + + gtAnalysis = new CMM_GTAnalysis(trueClustering, points, enableClassMerge); + + numGT0Classes = gtAnalysis.getNumberOfGT0Classes(); + + addValue("CA Seperability", gtAnalysis.getClassSeparability()); + addValue("CA Noise", gtAnalysis.getNoiseSeparability()); + addValue("CA Modell", gtAnalysis.getModelQuality()); + + /* init the matching and point distances */ + calculateMatching(); + + /* calculate the actual error */ + calculateError(); + } + + /** + * calculates the CMM specific matching between found clusters and ground truth clusters + */ + private void calculateMatching() { + + /** + * found cluster frequencies + */ + int[][] mapFC = new int[numFClusters][numGT0Classes]; + + /** + * ground truth cluster frequencies + */ + int[][] mapGT = new int[numGT0Classes][numGT0Classes]; + int[] sumsFC = new int[numFClusters]; + + // calculate fuzzy mapping from + pointInclusionProbFC = new 
double[numPoints][numFClusters]; + for (int p = 0; p < numPoints; p++) { + CMMPoint cmdp = gtAnalysis.getPoint(p); + // found cluster frequencies + for (int fc = 0; fc < numFClusters; fc++) { + Cluster cl = clustering.get(fc); + pointInclusionProbFC[p][fc] = cl.getInclusionProbability(cmdp); + if (pointInclusionProbFC[p][fc] >= pointInclusionProbThreshold) { + // make sure we don't count points twice that are contained in two + // merged clusters + if (cmdp.isNoise()) + continue; + mapFC[fc][cmdp.workclass()]++; + sumsFC[fc]++; + } + } + + // ground truth cluster frequencies + if (!cmdp.isNoise()) { + for (int hc = 0; hc < numGT0Classes; hc++) { + if (hc == cmdp.workclass()) { + mapGT[hc][hc]++; + } + else { + if (gtAnalysis.getGT0Cluster(hc).getInclusionProbability(cmdp) >= 1) { + mapGT[hc][cmdp.workclass()]++; + } + } + } + } + } + + // assign each found cluster to a hidden cluster + matchMap = new int[numFClusters]; + for (int fc = 0; fc < numFClusters; fc++) { + int matchIndex = -1; + // check if we only have one entry anyway + for (int hc0 = 0; hc0 < numGT0Classes; hc0++) { + if (mapFC[fc][hc0] != 0) { + if (matchIndex == -1) + matchIndex = hc0; + else { + matchIndex = -1; + break; + } + } + } + + // more then one entry, so look for most similar frequency profile + int minDiff = Integer.MAX_VALUE; + if (sumsFC[fc] != 0 && matchIndex == -1) { + ArrayList<Integer> fitCandidates = new ArrayList<Integer>(); + for (int hc0 = 0; hc0 < numGT0Classes; hc0++) { + int errDiff = 0; + for (int hc1 = 0; hc1 < numGT0Classes; hc1++) { + // fc profile doesn't fit into current hc profile + double freq_diff = mapFC[fc][hc1] - mapGT[hc0][hc1]; + if (freq_diff > 0) { + errDiff += freq_diff; + } + } + if (errDiff == 0) { + fitCandidates.add(hc0); + } + if (errDiff < minDiff) { + minDiff = errDiff; + matchIndex = hc0; + } + if (debug) { + // System.out.println("FC"+fc+"("+Arrays.toString(mapFC[fc])+") - HC0_"+hc0+"("+Arrays.toString(mapGT[hc0])+"):"+errDiff); + } + } + // if we 
have a fitting profile overwrite the min error choice + // if we have multiple fit candidates, use majority vote of + // corresponding classes + if (fitCandidates.size() != 0) { + int bestGTfit = fitCandidates.get(0); + for (int i = 1; i < fitCandidates.size(); i++) { + int GTfit = fitCandidates.get(i); + if (mapFC[fc][GTfit] > mapFC[fc][bestGTfit]) + bestGTfit = fitCandidates.get(i); + } + matchIndex = bestGTfit; + } + } + + matchMap[fc] = matchIndex; + int realMatch = -1; + if (matchIndex == -1) { + if (debug) + System.out.println("No cluster match: needs to be implemented?"); + } + else { + realMatch = gtAnalysis.getGT0Cluster(matchMap[fc]).getLabel(); + } + clustering.get(fc).setMeasureValue("CMM Match", "C" + realMatch); + clustering.get(fc).setMeasureValue("CMM Workclass", "C" + matchMap[fc]); + } + + // print matching table + if (debug) { + for (int i = 0; i < numFClusters; i++) { + System.out.print("C" + ((int) clustering.get(i).getId()) + " N:" + ((int) clustering.get(i).getWeight()) + + " | "); + for (int j = 0; j < numGT0Classes; j++) { + System.out.print(mapFC[i][j] + " "); + } + System.out.print(" = " + sumsFC[i] + " | "); + String match = "-"; + if (matchMap[i] != -1) { + match = Integer.toString(gtAnalysis.getGT0Cluster(matchMap[i]).getLabel()); + } + System.out.println(" --> " + match + "(work:" + matchMap[i] + ")"); + } + } + } + + /** + * Calculate the actual error values + */ + private void calculateError() { + int totalErrorCount = 0; + int totalRedundancy = 0; + int trueCoverage = 0; + int totalCoverage = 0; + + int numNoise = 0; + double errorNoise = 0; + double errorNoiseMax = 0; + + double errorMissed = 0; + double errorMissedMax = 0; + + double errorMisplaced = 0; + double errorMisplacedMax = 0; + + double totalError = 0.0; + double totalErrorMax = 0.0; + + /** + * mainly iterate over all points and find the right error value for the point. within the same run calculate + * various other stuff like coverage etc... 
+ */ + for (int p = 0; p < numPoints; p++) { + CMMPoint cmdp = gtAnalysis.getPoint(p); + double weight = cmdp.weight(); + // noise counter + if (cmdp.isNoise()) { + numNoise++; + // this is always 1 + errorNoiseMax += cmdp.connectivity * weight; + } + else { + errorMissedMax += cmdp.connectivity * weight; + errorMisplacedMax += cmdp.connectivity * weight; + } + // sum up maxError as the individual errors are the quality weighted + // between 0-1 + totalErrorMax += cmdp.connectivity * weight; + + double err = 0; + int coverage = 0; + + // check every FCluster + for (int c = 0; c < numFClusters; c++) { + // contained in cluster c? + if (pointInclusionProbFC[p][c] >= pointInclusionProbThreshold) { + coverage++; + + if (!cmdp.isNoise()) { + // PLACED CORRECTLY + if (matchMap[c] == cmdp.workclass()) { + } + // MISPLACED + else { + double errvalue = misplacedError(cmdp, c); + if (errvalue > err) + err = errvalue; + } + } + else { + // NOISE + double errvalue = noiseError(cmdp, c); + if (errvalue > err) + err = errvalue; + } + } + } + // not in any cluster + if (coverage == 0) { + // MISSED + if (!cmdp.isNoise()) { + err = missedError(cmdp, true); + errorMissed += weight * err; + } + // NOISE + else { + } + } + else { + if (!cmdp.isNoise()) { + errorMisplaced += err * weight; + } + else { + errorNoise += err * weight; + } + } + + /* processing of other evaluation values */ + totalError += err * weight; + if (err != 0) + totalErrorCount++; + if (coverage > 0) + totalCoverage++; // points covered by clustering (incl. noise) + if (coverage > 0 && !cmdp.isNoise()) + trueCoverage++; // points covered by clustering, don't count noise + if (coverage > 1) + totalRedundancy++; // include noise + + cmdp.p.setMeasureValue("CMM", err); + cmdp.p.setMeasureValue("Redundancy", coverage); + } + + addValue("CMM", (totalErrorMax != 0) ? 1 - totalError / totalErrorMax : 1); + addValue("CMM Missed", (errorMissedMax != 0) ? 
1 - errorMissed / errorMissedMax : 1); + addValue("CMM Misplaced", (errorMisplacedMax != 0) ? 1 - errorMisplaced / errorMisplacedMax : 1); + addValue("CMM Noise", (errorNoiseMax != 0) ? 1 - errorNoise / errorNoiseMax : 1); + addValue("CMM Basic", 1 - ((double) totalErrorCount / (double) numPoints)); + + if (debug) { + System.out.println("-------------"); + } + } + + private double noiseError(CMMPoint cmdp, int assignedClusterID) { + int gtAssignedID = matchMap[assignedClusterID]; + double error; + + // Cluster wasn't matched, so just contains noise + // TODO: Noiscluster? + // also happens when we decrease the radius and there is only a noise point + // in the center + if (gtAssignedID == -1) { + error = 1; + cmdp.p.setMeasureValue("CMM Type", "noise - cluster"); + } + else { + if (enableModelError + && gtAnalysis.getGT0Cluster(gtAssignedID).getInclusionProbability(cmdp) >= pointInclusionProbThreshold) { + // set to MIN_ERROR so we can still track the error + error = 0.00001; + cmdp.p.setMeasureValue("CMM Type", "noise - byModel"); + } + else { + error = 1 - gtAnalysis.getConnectionValue(cmdp, gtAssignedID); + cmdp.p.setMeasureValue("CMM Type", "noise"); + } + } + + return error; + } + + private double missedError(CMMPoint cmdp, boolean useHullDistance) { + cmdp.p.setMeasureValue("CMM Type", "missed"); + if (!useHullDistance) { + return cmdp.connectivity; + } + else { + // main idea: look at relative distance of missed point to cluster + double minHullDist = 1; + for (int fc = 0; fc < numFClusters; fc++) { + // if fc is mappend onto the class of the point, check it for its + // hulldist + if (matchMap[fc] != -1 && matchMap[fc] == cmdp.workclass()) { + if (clustering.get(fc) instanceof SphereCluster) { + SphereCluster sc = (SphereCluster) clustering.get(fc); + double distanceFC = sc.getCenterDistance(cmdp); + double radius = sc.getRadius(); + double hullDist = (distanceFC - radius) / (distanceFC + radius); + if (hullDist < minHullDist) + minHullDist = hullDist; + } 
+ else { + double min = 1; + double max = 1; + + // TODO: distance for random shape + // generate X points from the cluster with + // clustering.get(fc).sample(null) + // and find Min and Max values + + double hullDist = min / max; + if (hullDist < minHullDist) + minHullDist = hullDist; + } + } + } + + // use distance as weight + if (minHullDist > 1) + minHullDist = 1; + + double weight = (1 - Math.exp(-lamdaMissed * minHullDist)); + cmdp.p.setMeasureValue("HullDistWeight", weight); + + return weight * cmdp.connectivity; + } + } + + private double misplacedError(CMMPoint cmdp, int assignedClusterID) { + double weight = 0; + + int gtAssignedID = matchMap[assignedClusterID]; + // TODO take care of noise cluster? + if (gtAssignedID == -1) { + System.out.println("Point " + cmdp.getTimestamp() + " from gtcluster " + cmdp.trueClass + + " assigned to noise cluster " + assignedClusterID); + return 1; + } + + if (gtAssignedID == cmdp.workclass()) + return 0; + else { + // assigned and real GT0 cluster are not connected, but does the model + // have the + // chance of separating this point after all? 
+ if (enableModelError + && gtAnalysis.getGT0Cluster(gtAssignedID).getInclusionProbability(cmdp) >= pointInclusionProbThreshold) { + weight = 0; + cmdp.p.setMeasureValue("CMM Type", "missplaced - byModel"); + } + else { + // point was mapped onto wrong cluster (assigned), so check how far away + // the nearest point is within the wrongly assigned cluster + weight = 1 - gtAnalysis.getConnectionValue(cmdp, gtAssignedID); + } + } + double err_value; + // set to MIN_ERROR so we can still track the error + if (weight == 0) { + err_value = 0.00001; + } + else { + err_value = weight * cmdp.connectivity; + cmdp.p.setMeasureValue("CMM Type", "missplaced"); + } + + return err_value; + } + + public String getParameterString() { + String para = gtAnalysis.getParameterString(); + para += "lambdaMissed=" + lamdaMissed + ";"; + return para; + } + +} http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/9b178f63/samoa-api/src/main/java/org/apache/samoa/evaluation/measures/CMM_GTAnalysis.java ---------------------------------------------------------------------- diff --git a/samoa-api/src/main/java/org/apache/samoa/evaluation/measures/CMM_GTAnalysis.java b/samoa-api/src/main/java/org/apache/samoa/evaluation/measures/CMM_GTAnalysis.java new file mode 100644 index 0000000..a48c054 --- /dev/null +++ b/samoa-api/src/main/java/org/apache/samoa/evaluation/measures/CMM_GTAnalysis.java @@ -0,0 +1,847 @@ +package org.apache.samoa.evaluation.measures; + +/* + * #%L + * SAMOA + * %% + * Copyright (C) 2014 - 2015 Apache Software Foundation + * %% + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * #L% + */ + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.Iterator; + +import org.apache.samoa.instances.Instance; +import org.apache.samoa.moa.cluster.Clustering; +import org.apache.samoa.moa.core.AutoExpandVector; +import org.apache.samoa.moa.core.DataPoint; + +/** + * [CMM_GTAnalysis.java] + * + * CMM: Ground truth analysis + * + * Reference: Kremer et al., "An Effective Evaluation Measure for Clustering on Evolving Data Streams", KDD, 2011 + * + * @author Timm jansen Data Management and Data Exploration Group, RWTH Aachen University + */ + +/* + * TODO: - try to avoid calculating the radius multiple times - avoid the full + * distance map?
- knn functionality in clusters - noise error + */ +public class CMM_GTAnalysis { + + /** + * the given ground truth clustering + */ + private Clustering gtClustering; + + /** + * list of given points within the horizon + */ + private ArrayList<CMMPoint> cmmpoints; + + /** + * the newly calculate ground truth clustering + */ + private ArrayList<GTCluster> gt0Clusters; + + /** + * IDs of noise points + */ + private ArrayList<Integer> noise; + + /** + * total number of points + */ + private int numPoints; + + /** + * number of clusters of the original ground truth + */ + private int numGTClusters; + + /** + * number of classes of the original ground truth, in case of a micro clustering ground truth this differs from + * numGTClusters + */ + private int numGTClasses; + + /** + * number of classes after we are done with the analysis + */ + private int numGT0Classes; + + /** + * number of dimensions + */ + private int numDims; + + /** + * mapping between true cluster ID/class label of the original ground truth and the internal cluster ID/working class + * label. + * + * different original cluster IDs might map to the same new cluster ID due to merging of two clusters + */ + private HashMap<Integer, Integer> mapTrueLabelToWorkLabel; + + /** + * log of how clusters have been merged (for debugging) + */ + private int[] mergeMap; + + /** + * number of non-noise points that will create an error due to the underlying clustering model (e.g. point being + * covered by two clusters representing different classes) + */ + private int noiseErrorByModel; + + /** + * number of noise points that will create an error due to the underlying clustering model (e.g. 
noise point being + * covered by a cluster) + */ + private int pointErrorByModel; + + /** + * CMM debug mode + */ + private boolean debug = false; + + /******* CMM parameter ***********/ + + /** + * defines how many nearest neighbors will be used + */ + private int knnNeighbourhood = 2; + + /** + * the threshold which defines when ground truth clusters will be merged. set to 1 to disable merging + */ + private double tauConnection = 0.5; + + /** + * experimental (default: disabled) separate k for points to cluster and cluster to cluster + */ + private double clusterConnectionMaxPoints = knnNeighbourhood; + + /** + * experimental (default: disabled) use exponential connectivity function to model different behavior: closer points + * will have a stronger connection compared to the linear function. Use ConnRefXValue and ConnX to better parameterize + * lambda, which controls the decay of the connectivity + */ + private boolean useExpConnectivity = false; + private double lambdaConnRefXValue = 0.01; + private double lambdaConnX = 4; + private double lamdaConn; + + /******************************************/ + + /** + * Wrapper class for data points to store CMM relevant attributes + * + */ + protected class CMMPoint extends DataPoint { + /** + * Reference to original point + */ + protected DataPoint p = null; + + /** + * point ID + */ + protected int pID = 0; + + /** + * true class label + */ + protected int trueClass = -1; + + /** + * the connectivity of the point to its cluster + */ + protected double connectivity = 1.0; + + /** + * knn distnace within own cluster + */ + protected double knnInCluster = 0.0; + + /** + * knn indices (for debugging only) + */ + protected ArrayList<Integer> knnIndices; + + public CMMPoint(DataPoint point, int id) { + // make a copy, but keep reference + super(point, point.getTimestamp()); + p = point; + pID = id; + trueClass = (int) point.classValue(); + } + + /** + * Retruns the current working label of the cluster the point belongs 
to. The label can change due to merging of + * clusters. + * + * @return the current working class label + */ + protected int workclass() { + if (trueClass == -1) + return -1; + else + return mapTrueLabelToWorkLabel.get(trueClass); + } + } + + /** + * Main class to model the new clusters that will be the output of the cluster analysis + * + */ + protected class GTCluster { + /** points that are per definition in the cluster */ + private ArrayList<Integer> points = new ArrayList<Integer>(); + + /** + * a new GT cluster consists of one or more "old" GT clusters. Connected/overlapping clusters cannot be merged + * directly because of the underlying cluster model. E.g. for merging two spherical clusters the new cluster sphere + * can cover a lot more space then two separate smaller spheres. To keep the original coverage we need to keep the + * orignal clusters and merge them on an abstract level. + */ + private ArrayList<Integer> clusterRepresentations = new ArrayList<Integer>(); + + /** current work class (changes when merging) */ + private int workclass; + + /** original work class */ + private final int orgWorkClass; + + /** original class label */ + private final int label; + + /** clusters that have been merged into this cluster (debugging) */ + private ArrayList<Integer> mergedWorkLabels = null; + + /** average knn distance of all points in the cluster */ + private double knnMeanAvg = 0; + + /** average deviation of knn distance of all points */ + private double knnDevAvg = 0; + + /** connectivity of the cluster to all other clusters */ + private ArrayList<Double> connections = new ArrayList<Double>(); + + private GTCluster(int workclass, int label, int gtClusteringID) { + this.orgWorkClass = workclass; + this.workclass = workclass; + this.label = label; + this.clusterRepresentations.add(gtClusteringID); + } + + /** + * The original class label the cluster represents + * + * @return original class label + */ + protected int getLabel() { + return label; + } + + 
/** + * Calculate the probability of the point being covered through the cluster + * + * @param point + * to calculate the probability for + * @return probability of the point being covered through the cluster + */ + protected double getInclusionProbability(CMMPoint point) { + double prob = Double.MIN_VALUE; + // check all cluster representatives for coverage + for (int c = 0; c < clusterRepresentations.size(); c++) { + double tmp_prob = gtClustering.get(clusterRepresentations.get(c)).getInclusionProbability(point); + if (tmp_prob > prob) + prob = tmp_prob; + } + return prob; + } + + /** + * calculate knn distances of points within own cluster + average knn distance and average knn distance deviation of + * all points + */ + private void calculateKnn() { + for (int p0 : points) { + CMMPoint cmdp = cmmpoints.get(p0); + if (!cmdp.isNoise()) { + AutoExpandVector<Double> knnDist = new AutoExpandVector<Double>(); + AutoExpandVector<Integer> knnPointIndex = new AutoExpandVector<Integer>(); + + // calculate nearest neighbours + getKnnInCluster(cmdp, knnNeighbourhood, points, knnDist, knnPointIndex); + + // TODO: What to do if we have less then k neighbours? + double avgKnn = 0; + for (int i = 0; i < knnDist.size(); i++) { + avgKnn += knnDist.get(i); + } + if (knnDist.size() != 0) + avgKnn /= knnDist.size(); + cmdp.knnInCluster = avgKnn; + cmdp.knnIndices = knnPointIndex; + cmdp.p.setMeasureValue("knnAvg", cmdp.knnInCluster); + + knnMeanAvg += avgKnn; + knnDevAvg += Math.pow(avgKnn, 2); + } + } + knnMeanAvg = knnMeanAvg / (double) points.size(); + knnDevAvg = knnDevAvg / (double) points.size(); + + double variance = knnDevAvg - Math.pow(knnMeanAvg, 2.0); + // Due to numerical errors, small negative values can occur. 
+ if (variance <= 0.0) + variance = 1e-50; + knnDevAvg = Math.sqrt(variance); + + } + + /** + * Calculate the connection of a cluster to this cluster + * + * The connection is the average of the strongest point connections (at most clusterConnectionMaxPoints of them) + * and is stored in connections at position otherCid; the connection of a cluster to itself is 1 + * + * @param otherCid + * cluster id of the other cluster + * @param initial + * flag for initial run; when set, each point's connection value is also recorded as a measure on the point + */ + private void calculateClusterConnection(int otherCid, boolean initial) { + double avgConnection = 0; + if (workclass == otherCid) { + avgConnection = 1; + } + else { + AutoExpandVector<Double> kmax = new AutoExpandVector<Double>(); + AutoExpandVector<Integer> kmaxIndexes = new AutoExpandVector<Integer>(); + + for (int p : points) { + CMMPoint cmdp = cmmpoints.get(p); + double con_p_Cj = getConnectionValue(cmmpoints.get(p), otherCid); + double connection = cmdp.connectivity * con_p_Cj; + if (initial) { + cmdp.p.setMeasureValue("Connection to C" + otherCid, con_p_Cj); + } + + // connection + if (kmax.size() < clusterConnectionMaxPoints || connection > kmax.get(kmax.size() - 1)) { + int index = 0; + while (index < kmax.size() && connection < kmax.get(index)) { + index++; + } + kmax.add(index, connection); + kmaxIndexes.add(index, p); + if (kmax.size() > clusterConnectionMaxPoints) { + kmax.remove(kmax.size() - 1); + kmaxIndexes.remove(kmaxIndexes.size() - 1); // drop the matching index so both lists stay aligned + } + } + } + // connection + for (int k = 0; k < kmax.size(); k++) { + avgConnection += kmax.get(k); + } + avgConnection /= kmax.size(); + } + + if (otherCid < connections.size()) { + connections.set(otherCid, avgConnection); + } + else if (connections.size() == otherCid) { + connections.add(avgConnection); + } + else + System.out.println("Something is going really wrong with the connection listing!"
+ knnNeighbourhood + " " + + tauConnection); + } + + /** + * Merge a cluster into this cluster + * + * @param mergeID + * the ID of the cluster to be merged + */ + private void mergeCluster(int mergeID) { + if (mergeID < gt0Clusters.size()) { + // track merging (debugging) + for (int i = 0; i < numGTClasses; i++) { + if (mergeMap[i] == mergeID) + mergeMap[i] = workclass; + if (mergeMap[i] > mergeID) + mergeMap[i]--; + } + GTCluster gtcMerge = gt0Clusters.get(mergeID); + if (debug) + System.out.println("Merging C" + gtcMerge.workclass + " into C" + workclass + + " with Con " + connections.get(mergeID) + " / " + gtcMerge.connections.get(workclass)); + + // update mapTrueLabelToWorkLabel + mapTrueLabelToWorkLabel.put(gtcMerge.label, workclass); + Iterator iterator = mapTrueLabelToWorkLabel.keySet().iterator(); + while (iterator.hasNext()) { + Integer key = (Integer) iterator.next(); + // update pointer of already merged cluster + int value = mapTrueLabelToWorkLabel.get(key); + if (value == mergeID) + mapTrueLabelToWorkLabel.put(key, workclass); + if (value > mergeID) + mapTrueLabelToWorkLabel.put(key, value - 1); + } + + // merge points from B into A + points.addAll(gtcMerge.points); + clusterRepresentations.addAll(gtcMerge.clusterRepresentations); + if (mergedWorkLabels == null) { + mergedWorkLabels = new ArrayList<Integer>(); + } + mergedWorkLabels.add(gtcMerge.orgWorkClass); + if (gtcMerge.mergedWorkLabels != null) + mergedWorkLabels.addAll(gtcMerge.mergedWorkLabels); + + gt0Clusters.remove(mergeID); + + // update workclass labels + for (int c = mergeID; c < gt0Clusters.size(); c++) { + gt0Clusters.get(c).workclass = c; + } + + // update knn distances + calculateKnn(); + for (int c = 0; c < gt0Clusters.size(); c++) { + gt0Clusters.get(c).connections.remove(mergeID); + + // recalculate connection from other clusters to the new merged one + gt0Clusters.get(c).calculateClusterConnection(workclass, false); + // and from new merged one to other clusters + 
gt0Clusters.get(workclass).calculateClusterConnection(c, false); + } + } + else { + System.out.println("Merge indices are not valid"); + } + } + } + + /** + * @param trueClustering + * the ground truth clustering + * @param points + * data points + * @param enableClassMerge + * allow class merging (should be set to true on default) + */ + public CMM_GTAnalysis(Clustering trueClustering, ArrayList<DataPoint> points, boolean enableClassMerge) { + if (debug) + System.out.println("GT Analysis Debug Output"); + + noiseErrorByModel = 0; + pointErrorByModel = 0; + if (!enableClassMerge) { + tauConnection = 1.0; + } + + lamdaConn = -Math.log(lambdaConnRefXValue) / Math.log(2) / lambdaConnX; + + this.gtClustering = trueClustering; + + numPoints = points.size(); + numDims = points.get(0).numAttributes() - 1; + numGTClusters = gtClustering.size(); + + // init mappings between work and true labels + mapTrueLabelToWorkLabel = new HashMap<Integer, Integer>(); + + // set up base of new clustering + gt0Clusters = new ArrayList<GTCluster>(); + int numWorkClasses = 0; + // create label to worklabel mapping as real labels can be just a set of + // unordered integers + for (int i = 0; i < numGTClusters; i++) { + int label = (int) gtClustering.get(i).getGroundTruth(); + if (!mapTrueLabelToWorkLabel.containsKey(label)) { + gt0Clusters.add(new GTCluster(numWorkClasses, label, i)); + mapTrueLabelToWorkLabel.put(label, numWorkClasses); + numWorkClasses++; + } + else { + gt0Clusters.get(mapTrueLabelToWorkLabel.get(label)).clusterRepresentations.add(i); + } + } + numGTClasses = numWorkClasses; + + mergeMap = new int[numGTClasses]; + for (int i = 0; i < numGTClasses; i++) { + mergeMap[i] = i; + } + + // create cmd point wrapper instances + cmmpoints = new ArrayList<CMMPoint>(); + for (int p = 0; p < points.size(); p++) { + CMMPoint cmdp = new CMMPoint(points.get(p), p); + cmmpoints.add(cmdp); + } + + // split points up into their GTClusters and Noise (according to class + // labels) + noise = 
new ArrayList<Integer>(); + for (int p = 0; p < numPoints; p++) { + if (cmmpoints.get(p).isNoise()) { + noise.add(p); + } + else { + gt0Clusters.get(cmmpoints.get(p).workclass()).points.add(p); + } + } + + // calculate initial knnMean and knnDev + for (GTCluster gtc : gt0Clusters) { + gtc.calculateKnn(); + } + + // calculate cluster connections + calculateGTClusterConnections(); + + // calculate point connections with own clusters + calculateGTPointQualities(); + + if (debug) + System.out.println("GT Analysis Debug End"); + + } + + /** + * Calculate the connection of a point to a cluster + * + * @param cmmp + * the point to calculate the connection for + * @param clusterID + * the corresponding cluster + * @return the connection value + */ + // TODO: Cache the connection value for a point to the different clusters??? + protected double getConnectionValue(CMMPoint cmmp, int clusterID) { + AutoExpandVector<Double> knnDist = new AutoExpandVector<Double>(); + AutoExpandVector<Integer> knnPointIndex = new AutoExpandVector<Integer>(); + + // calculate the knn distance of the point to the cluster + getKnnInCluster(cmmp, knnNeighbourhood, gt0Clusters.get(clusterID).points, knnDist, knnPointIndex); + + // TODO: What to do if we have less then k neighbors? + double avgDist = 0; + for (int i = 0; i < knnDist.size(); i++) { + avgDist += knnDist.get(i); + } + // what to do if we only have a single point??? 
+ if (knnDist.size() != 0) + avgDist /= knnDist.size(); + else + return 0; + + // get the upper knn distance of the cluster + double upperKnn = gt0Clusters.get(clusterID).knnMeanAvg + gt0Clusters.get(clusterID).knnDevAvg; + + /* + * calculate the connectivity based on knn distance of the point within the + * cluster and the upper knn distance of the cluster + */ + if (avgDist < upperKnn) { + return 1; + } + else { + // value that should be reached at upperKnn distance + // Choose connection formula + double conn; + if (useExpConnectivity) + conn = Math.pow(2, -lamdaConn * (avgDist - upperKnn) / upperKnn); + else + conn = upperKnn / avgDist; + + if (Double.isNaN(conn)) + System.out.println("Connectivity NaN at " + cmmp.p.getTimestamp()); + + return conn; + } + } + + /** + * @param cmmp + * point to calculate knn distance for + * @param k + * number of nearest neighbors to look for + * @param pointIDs + * list of point IDs to check + * @param knnDist + * sorted list of smallest knn distances (can already be filled to make updates possible) + * @param knnPointIndex + * list of corresponding knn indices + */ + private void getKnnInCluster(CMMPoint cmmp, int k, + ArrayList<Integer> pointIDs, + AutoExpandVector<Double> knnDist, + AutoExpandVector<Integer> knnPointIndex) { + + // iterate over every point in the choosen cluster, cal distance and insert + // into list + for (int p1 = 0; p1 < pointIDs.size(); p1++) { + int pid = pointIDs.get(p1); + if (cmmp.pID == pid) + continue; + double dist = distance(cmmp, cmmpoints.get(pid)); + if (knnDist.size() < k || dist < knnDist.get(knnDist.size() - 1)) { + int index = 0; + while (index < knnDist.size() && dist > knnDist.get(index)) { + index++; + } + knnDist.add(index, dist); + knnPointIndex.add(index, pid); + if (knnDist.size() > k) { + knnDist.remove(knnDist.size() - 1); + knnPointIndex.remove(knnPointIndex.size() - 1); + } + } + } + } + + /** + * calculate initial connectivities + */ + private void calculateGTPointQualities() 
{ + for (int p = 0; p < numPoints; p++) { + CMMPoint cmdp = cmmpoints.get(p); + if (!cmdp.isNoise()) { + cmdp.connectivity = getConnectionValue(cmdp, cmdp.workclass()); + cmdp.p.setMeasureValue("Connectivity", cmdp.connectivity); + } + } + } + + /** + * Calculate connections between clusters and merge clusters accordingly as long as connections exceed threshold + */ + private void calculateGTClusterConnections() { + for (int c0 = 0; c0 < gt0Clusters.size(); c0++) { + for (int c1 = 0; c1 < gt0Clusters.size(); c1++) { + gt0Clusters.get(c0).calculateClusterConnection(c1, true); + } + } + + boolean changedConnection = true; + while (changedConnection) { + if (debug) { + System.out.println("Cluster Connection"); + for (int c = 0; c < gt0Clusters.size(); c++) { + System.out.print("C" + gt0Clusters.get(c).label + " --> "); + for (int c1 = 0; c1 < gt0Clusters.get(c).connections.size(); c1++) { + System.out.print(" C" + gt0Clusters.get(c1).label + ": " + gt0Clusters.get(c).connections.get(c1)); + } + System.out.println(""); + } + System.out.println(""); + } + + double max = 0; + int maxIndexI = -1; + int maxIndexJ = -1; + + changedConnection = false; + for (int c0 = 0; c0 < gt0Clusters.size(); c0++) { + for (int c1 = c0 + 1; c1 < gt0Clusters.size(); c1++) { + if (c0 == c1) + continue; + double min = Math.min(gt0Clusters.get(c0).connections.get(c1), gt0Clusters.get(c1).connections.get(c0)); + if (min > max) { + max = min; + maxIndexI = c0; + maxIndexJ = c1; + } + } + } + if (maxIndexI != -1 && max > tauConnection) { + gt0Clusters.get(maxIndexI).mergeCluster(maxIndexJ); + if (debug) + System.out.println("Merging " + maxIndexI + " and " + maxIndexJ + " because of connection " + max); + + changedConnection = true; + } + } + numGT0Classes = gt0Clusters.size(); + } + + /** + * Calculates how well the original clusters are separable. 
Small values indicate bad separability, values close to 1 + * indicate good separability + * + * @return index of seperability + */ + public double getClassSeparability() { + // int totalConn = numGTClasses*(numGTClasses-1)/2; + // int mergedConn = 0; + // for(GTCluster gt : gt0Clusters){ + // int merged = gt.clusterRepresentations.size(); + // if(merged > 1) + // mergedConn+=merged * (merged-1)/2; + // } + // if(totalConn == 0) + // return 0; + // else + // return 1-mergedConn/(double)totalConn; + return numGT0Classes / (double) numGTClasses; + + } + + /** + * Calculates how well noise is separable from the given clusters Small values indicate bad separability, values close + * to 1 indicate good separability + * + * @return index of noise separability + */ + public double getNoiseSeparability() { + if (noise.isEmpty()) + return 1; + + double connectivity = 0; + for (int p : noise) { + CMMPoint npoint = cmmpoints.get(p); + double maxConnection = 0; + + // TODO: some kind of pruning possible. what about weighting? 
+ for (int c = 0; c < gt0Clusters.size(); c++) { + double connection = getConnectionValue(npoint, c); + if (connection > maxConnection) + maxConnection = connection; + } + connectivity += maxConnection; + npoint.p.setMeasureValue("MaxConnection", maxConnection); + } + + return 1 - (connectivity / noise.size()); + } + + /** + * Calculates the relative number of errors being caused by the underlying cluster model + * + * @return quality of the model + */ + public double getModelQuality() { + for (int p = 0; p < numPoints; p++) { + CMMPoint cmdp = cmmpoints.get(p); + for (int hc = 0; hc < numGTClusters; hc++) { + if (gtClustering.get(hc).getGroundTruth() != cmdp.trueClass) { + if (gtClustering.get(hc).getInclusionProbability(cmdp) >= 1) { + if (!cmdp.isNoise()) + pointErrorByModel++; + else + noiseErrorByModel++; + break; + } + } + } + } + if (debug) + System.out.println("Error by model: noise " + noiseErrorByModel + " point " + pointErrorByModel); + + return 1 - ((pointErrorByModel + noiseErrorByModel) / (double) numPoints); + } + + /** + * Get CMM internal point + * + * @param index + * of the point + * @return cmm point + */ + protected CMMPoint getPoint(int index) { + return cmmpoints.get(index); + } + + /** + * Return cluster + * + * @param index + * of the cluster to return + * @return cluster + */ + protected GTCluster getGT0Cluster(int index) { + return gt0Clusters.get(index); + } + + /** + * Number of classes/clusters of the new clustering + * + * @return number of new clusters + */ + protected int getNumberOfGT0Classes() { + return numGT0Classes; + } + + /** + * Calculates Euclidian distance + * + * @param inst1 + * point as double array + * @param inst2 + * point as double array + * @return euclidian distance + */ + private double distance(Instance inst1, Instance inst2) { + return distance(inst1, inst2.toDoubleArray()); + + } + + /** + * Calculates Euclidian distance + * + * @param inst1 + * point as an instance + * @param inst2 + * point as double array + 
* @return euclidian distance + */ + private double distance(Instance inst1, double[] inst2) { + double distance = 0.0; + for (int i = 0; i < numDims; i++) { + double d = inst1.value(i) - inst2[i]; + distance += d * d; + } + return Math.sqrt(distance); + } + + /** + * String with main CMM parameters + * + * @return main CMM parameter + */ + public String getParameterString() { + String para = ""; + para += "k=" + knnNeighbourhood + ";"; + if (useExpConnectivity) { + para += "lambdaConnX=" + lambdaConnX + ";"; + para += "lambdaConn=" + lamdaConn + ";"; + para += "lambdaConnRef=" + lambdaConnRefXValue + ";"; + } + para += "m=" + clusterConnectionMaxPoints + ";"; + para += "tauConn=" + tauConnection + ";"; + + return para; + } +}
