http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/9b178f63/samoa-api/src/main/java/org/apache/samoa/learners/clusterers/ClustreamClustererAdapter.java ---------------------------------------------------------------------- diff --git a/samoa-api/src/main/java/org/apache/samoa/learners/clusterers/ClustreamClustererAdapter.java b/samoa-api/src/main/java/org/apache/samoa/learners/clusterers/ClustreamClustererAdapter.java new file mode 100644 index 0000000..e0c1cb3 --- /dev/null +++ b/samoa-api/src/main/java/org/apache/samoa/learners/clusterers/ClustreamClustererAdapter.java @@ -0,0 +1,171 @@ +package org.apache.samoa.learners.clusterers; + +/* + * #%L + * SAMOA + * %% + * Copyright (C) 2014 - 2015 Apache Software Foundation + * %% + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * #L% + */ +/** + * License + */ +import org.apache.samoa.instances.Instance; +import org.apache.samoa.instances.Instances; +import org.apache.samoa.instances.InstancesHeader; +import org.apache.samoa.moa.cluster.Clustering; +import org.apache.samoa.moa.clusterers.clustream.Clustream; + +import com.github.javacliparser.ClassOption; +import com.github.javacliparser.Configurable; + +/** + * + * Base class for adapting Clustream clusterer. 
+ * + */ +public class ClustreamClustererAdapter implements LocalClustererAdapter, Configurable { + + /** + * + */ + private static final long serialVersionUID = 4372366401338704353L; + + public ClassOption learnerOption = new ClassOption("learner", 'l', + "Clusterer to train.", org.apache.samoa.moa.clusterers.Clusterer.class, Clustream.class.getName()); + /** + * The learner. + */ + protected org.apache.samoa.moa.clusterers.Clusterer learner; + + /** + * The is init. + */ + protected Boolean isInit; + + /** + * The dataset. + */ + protected Instances dataset; + + @Override + public void setDataset(Instances dataset) { + this.dataset = dataset; + } + + /** + * Instantiates a new learner. + * + * @param learner + * the learner + * @param dataset + * the dataset + */ + public ClustreamClustererAdapter(org.apache.samoa.moa.clusterers.Clusterer learner, Instances dataset) { + this.learner = learner.copy(); + this.isInit = false; + this.dataset = dataset; + } + + /** + * Instantiates a new learner. + * + * @param learner + * the learner + * @param dataset + * the dataset + */ + public ClustreamClustererAdapter() { + this.learner = ((org.apache.samoa.moa.clusterers.Clusterer) this.learnerOption.getValue()).copy(); + this.isInit = false; + // this.dataset = dataset; + } + + /** + * Creates a new learner object. + * + * @return the learner + */ + @Override + public ClustreamClustererAdapter create() { + ClustreamClustererAdapter l = new ClustreamClustererAdapter(learner, dataset); + if (dataset == null) { + System.out.println("dataset null while creating"); + } + return l; + } + + /** + * Trains this classifier incrementally using the given instance. 
+ * + * @param inst + * the instance to be used for training + */ + @Override + public void trainOnInstance(Instance inst) { + if (this.isInit == false) { + this.isInit = true; + InstancesHeader instances = new InstancesHeader(dataset); + this.learner.setModelContext(instances); + this.learner.prepareForUse(); + } + if (inst.weight() > 0) { + inst.setDataset(dataset); + learner.trainOnInstance(inst); + } + } + + /** + * Predicts the class memberships for a given instance. If an instance is unclassified, the returned array elements + * must be all zero. + * + * @param inst + * the instance to be classified + * @return an array containing the estimated membership probabilities of the test instance in each class + */ + @Override + public double[] getVotesForInstance(Instance inst) { + double[] ret; + inst.setDataset(dataset); + if (this.isInit == false) { + ret = new double[dataset.numClasses()]; + } else { + ret = learner.getVotesForInstance(inst); + } + return ret; + } + + /** + * Resets this classifier. It must be similar to starting a new classifier from scratch. + * + */ + @Override + public void resetLearning() { + learner.resetLearning(); + } + + public boolean implementsMicroClusterer() { + return this.learner.implementsMicroClusterer(); + } + + public Clustering getMicroClusteringResult() { + return this.learner.getMicroClusteringResult(); + } + + public Instances getDataset() { + return this.dataset; + } + +}
http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/9b178f63/samoa-api/src/main/java/org/apache/samoa/learners/clusterers/LocalClustererAdapter.java ---------------------------------------------------------------------- diff --git a/samoa-api/src/main/java/org/apache/samoa/learners/clusterers/LocalClustererAdapter.java b/samoa-api/src/main/java/org/apache/samoa/learners/clusterers/LocalClustererAdapter.java new file mode 100644 index 0000000..4a3e5e9 --- /dev/null +++ b/samoa-api/src/main/java/org/apache/samoa/learners/clusterers/LocalClustererAdapter.java @@ -0,0 +1,81 @@ +package org.apache.samoa.learners.clusterers; + +/* + * #%L + * SAMOA + * %% + * Copyright (C) 2014 - 2015 Apache Software Foundation + * %% + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * #L% + */ + +import java.io.Serializable; + +import org.apache.samoa.instances.Instance; +import org.apache.samoa.instances.Instances; +import org.apache.samoa.moa.cluster.Clustering; + +/** + * Learner interface for non-distributed learners. + * + * @author abifet + */ +public interface LocalClustererAdapter extends Serializable { + + /** + * Creates a new learner object. + * + * @return the learner + */ + LocalClustererAdapter create(); + + /** + * Predicts the class memberships for a given instance. If an instance is unclassified, the returned array elements + * must be all zero. 
+ * + * @param inst + * the instance to be classified + * @return an array containing the estimated membership probabilities of the test instance in each class + */ + double[] getVotesForInstance(Instance inst); + + /** + * Resets this classifier. It must be similar to starting a new classifier from scratch. + * + */ + void resetLearning(); + + /** + * Trains this classifier incrementally using the given instance. + * + * @param inst + * the instance to be used for training + */ + void trainOnInstance(Instance inst); + + /** + * Sets where to obtain the information of attributes of Instances + * + * @param dataset + * the dataset that contains the information + */ + public void setDataset(Instances dataset); + + public Instances getDataset(); + + public boolean implementsMicroClusterer(); + + public Clustering getMicroClusteringResult(); + +} http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/9b178f63/samoa-api/src/main/java/org/apache/samoa/learners/clusterers/LocalClustererProcessor.java ---------------------------------------------------------------------- diff --git a/samoa-api/src/main/java/org/apache/samoa/learners/clusterers/LocalClustererProcessor.java b/samoa-api/src/main/java/org/apache/samoa/learners/clusterers/LocalClustererProcessor.java new file mode 100644 index 0000000..163184f --- /dev/null +++ b/samoa-api/src/main/java/org/apache/samoa/learners/clusterers/LocalClustererProcessor.java @@ -0,0 +1,200 @@ +package org.apache.samoa.learners.clusterers; + +/* + * #%L + * SAMOA + * %% + * Copyright (C) 2014 - 2015 Apache Software Foundation + * %% + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * #L% + */ +/** + * License + */ +import org.apache.samoa.core.ContentEvent; +import org.apache.samoa.core.Processor; +import org.apache.samoa.evaluation.ClusteringEvaluationContentEvent; +import org.apache.samoa.evaluation.ClusteringResultContentEvent; +import org.apache.samoa.instances.DenseInstance; +import org.apache.samoa.instances.Instance; +import org.apache.samoa.moa.cluster.Clustering; +import org.apache.samoa.moa.core.DataPoint; +import org.apache.samoa.topology.Stream; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +//import weka.core.Instance; + +/** + * The Class LearnerProcessor. + */ +final public class LocalClustererProcessor implements Processor { + + /** + * + */ + private static final long serialVersionUID = -1577910988699148691L; + private static final Logger logger = LoggerFactory + .getLogger(LocalClustererProcessor.class); + private LocalClustererAdapter model; + private Stream outputStream; + private int modelId; + private long instancesCount = 0; + private long sampleFrequency = 1000; + + public long getSampleFrequency() { + return sampleFrequency; + } + + public void setSampleFrequency(long sampleFrequency) { + this.sampleFrequency = sampleFrequency; + } + + /** + * Sets the learner. + * + * @param model + * the model to set + */ + public void setLearner(LocalClustererAdapter model) { + this.model = model; + } + + /** + * Gets the learner. + * + * @return the model + */ + public LocalClustererAdapter getLearner() { + return model; + } + + /** + * Set the output streams. 
+ * + * @param outputStream + * the new output stream {@link PredictionCombinerPE}. + */ + public void setOutputStream(Stream outputStream) { + + this.outputStream = outputStream; + } + + /** + * Gets the output stream. + * + * @return the output stream + */ + public Stream getOutputStream() { + return outputStream; + } + + /** + * Gets the instances count. + * + * @return number of observation vectors used in training iteration. + */ + public long getInstancesCount() { + return instancesCount; + } + + /** + * Update stats. + * + * @param event + * the event + */ + private void updateStats(ContentEvent event) { + Instance instance; + if (event instanceof ClusteringContentEvent) { + // Local Clustering + ClusteringContentEvent ev = (ClusteringContentEvent) event; + instance = ev.getInstance(); + DataPoint point = new DataPoint(instance, Integer.parseInt(event.getKey())); + model.trainOnInstance(point); + instancesCount++; + } + + if (event instanceof ClusteringResultContentEvent) { + // Global Clustering + ClusteringResultContentEvent ev = (ClusteringResultContentEvent) event; + Clustering clustering = ev.getClustering(); + + for (int i = 0; i < clustering.size(); i++) { + instance = new DenseInstance(1.0, clustering.get(i).getCenter()); + instance.setDataset(model.getDataset()); + DataPoint point = new DataPoint(instance, Integer.parseInt(event.getKey())); + model.trainOnInstance(point); + instancesCount++; + } + } + + if (instancesCount % this.sampleFrequency == 0) { + logger.info("Trained model using {} events with classifier id {}", instancesCount, this.modelId); // getId()); + } + } + + /** + * On event. 
+ * + * @param event + * the event + * @return true, if successful + */ + @Override + public boolean process(ContentEvent event) { + + if (event.isLastEvent() || + (instancesCount > 0 && instancesCount % this.sampleFrequency == 0)) { + if (model.implementsMicroClusterer()) { + + Clustering clustering = model.getMicroClusteringResult(); + + ClusteringResultContentEvent resultEvent = new ClusteringResultContentEvent(clustering, event.isLastEvent()); + + this.outputStream.put(resultEvent); + } + } + + updateStats(event); + return false; + } + + /* + * (non-Javadoc) + * + * @see samoa.core.Processor#onCreate(int) + */ + @Override + public void onCreate(int id) { + this.modelId = id; + model = model.create(); + } + + /* + * (non-Javadoc) + * + * @see samoa.core.Processor#newProcessor(samoa.core.Processor) + */ + @Override + public Processor newProcessor(Processor sourceProcessor) { + LocalClustererProcessor newProcessor = new LocalClustererProcessor(); + LocalClustererProcessor originProcessor = (LocalClustererProcessor) sourceProcessor; + if (originProcessor.getLearner() != null) { + newProcessor.setLearner(originProcessor.getLearner().create()); + } + newProcessor.setOutputStream(originProcessor.getOutputStream()); + return newProcessor; + } +} http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/9b178f63/samoa-api/src/main/java/org/apache/samoa/learners/clusterers/SingleLearner.java ---------------------------------------------------------------------- diff --git a/samoa-api/src/main/java/org/apache/samoa/learners/clusterers/SingleLearner.java b/samoa-api/src/main/java/org/apache/samoa/learners/clusterers/SingleLearner.java new file mode 100644 index 0000000..f6173f6 --- /dev/null +++ b/samoa-api/src/main/java/org/apache/samoa/learners/clusterers/SingleLearner.java @@ -0,0 +1,102 @@ +package org.apache.samoa.learners.clusterers; + +/* + * #%L + * SAMOA + * %% + * Copyright (C) 2014 - 2015 Apache Software Foundation + * %% + * Licensed under the Apache License, 
Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * #L% + */ + +/** + * License + */ + +import com.google.common.collect.ImmutableSet; + +import java.util.Set; + +import org.apache.samoa.core.Processor; +import org.apache.samoa.instances.Instances; +import org.apache.samoa.learners.Learner; +import org.apache.samoa.topology.Stream; +import org.apache.samoa.topology.TopologyBuilder; + +import com.github.javacliparser.ClassOption; +import com.github.javacliparser.Configurable; + +/** + * + * Learner that contain a single learner. 
+ * + */ +public final class SingleLearner implements Learner, Configurable { + + private static final long serialVersionUID = 684111382631697031L; + + private LocalClustererProcessor learnerP; + + private Stream resultStream; + + private Instances dataset; + + public ClassOption learnerOption = new ClassOption("learner", 'l', + "Learner to train.", LocalClustererAdapter.class, ClustreamClustererAdapter.class.getName()); + + private TopologyBuilder builder; + + private int parallelism; + + @Override + public void init(TopologyBuilder builder, Instances dataset, int parallelism) { + this.builder = builder; + this.dataset = dataset; + this.parallelism = parallelism; + this.setLayout(); + } + + protected void setLayout() { + learnerP = new LocalClustererProcessor(); + LocalClustererAdapter learner = (LocalClustererAdapter) this.learnerOption.getValue(); + learner.setDataset(this.dataset); + learnerP.setLearner(learner); + + this.builder.addProcessor(learnerP, this.parallelism); + resultStream = this.builder.createStream(learnerP); + + learnerP.setOutputStream(resultStream); + } + + /* + * (non-Javadoc) + * + * @see samoa.classifiers.Classifier#getInputProcessingItem() + */ + @Override + public Processor getInputProcessor() { + return learnerP; + } + + /* + * (non-Javadoc) + * + * @see samoa.learners.Learner#getResultStreams() + */ + @Override + public Set<Stream> getResultStreams() { + Set<Stream> streams = ImmutableSet.of(this.resultStream); + return streams; + } +} http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/9b178f63/samoa-api/src/main/java/org/apache/samoa/learners/clusterers/simple/ClusteringDistributorProcessor.java ---------------------------------------------------------------------- diff --git a/samoa-api/src/main/java/org/apache/samoa/learners/clusterers/simple/ClusteringDistributorProcessor.java b/samoa-api/src/main/java/org/apache/samoa/learners/clusterers/simple/ClusteringDistributorProcessor.java new file mode 100644 index 0000000..2b3e01c 
--- /dev/null +++ b/samoa-api/src/main/java/org/apache/samoa/learners/clusterers/simple/ClusteringDistributorProcessor.java @@ -0,0 +1,100 @@ +package org.apache.samoa.learners.clusterers.simple; + +/* + * #%L + * SAMOA + * %% + * Copyright (C) 2014 - 2015 Apache Software Foundation + * %% + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * #L% + */ +/** + * License + */ +import org.apache.samoa.core.ContentEvent; +import org.apache.samoa.core.Processor; +import org.apache.samoa.evaluation.ClusteringEvaluationContentEvent; +import org.apache.samoa.learners.clusterers.ClusteringContentEvent; +import org.apache.samoa.moa.core.DataPoint; +import org.apache.samoa.topology.Stream; + +/** + * The Class ClusteringDistributorPE. + */ +public class ClusteringDistributorProcessor implements Processor { + + private static final long serialVersionUID = -1550901409625192730L; + + private Stream outputStream; + private Stream evaluationStream; + private int numInstances; + + public Stream getOutputStream() { + return outputStream; + } + + public void setOutputStream(Stream outputStream) { + this.outputStream = outputStream; + } + + public Stream getEvaluationStream() { + return evaluationStream; + } + + public void setEvaluationStream(Stream evaluationStream) { + this.evaluationStream = evaluationStream; + } + + /** + * Process event. 
+ * + * @param event + * the event + * @return true, if successful + */ + public boolean process(ContentEvent event) { + // distinguish between ClusteringContentEvent and + // ClusteringEvaluationContentEvent + if (event instanceof ClusteringContentEvent) { + ClusteringContentEvent cce = (ClusteringContentEvent) event; + outputStream.put(event); + if (cce.isSample()) { + evaluationStream.put(new ClusteringEvaluationContentEvent(null, + new DataPoint(cce.getInstance(), numInstances++), cce.isLastEvent())); + } + } else if (event instanceof ClusteringEvaluationContentEvent) { + evaluationStream.put(event); + } + return true; + } + + /* + * (non-Javadoc) + * + * @see samoa.core.Processor#newProcessor(samoa.core.Processor) + */ + @Override + public Processor newProcessor(Processor sourceProcessor) { + ClusteringDistributorProcessor newProcessor = new ClusteringDistributorProcessor(); + ClusteringDistributorProcessor originProcessor = (ClusteringDistributorProcessor) sourceProcessor; + if (originProcessor.getOutputStream() != null) + newProcessor.setOutputStream(originProcessor.getOutputStream()); + if (originProcessor.getEvaluationStream() != null) + newProcessor.setEvaluationStream(originProcessor.getEvaluationStream()); + return newProcessor; + } + + public void onCreate(int id) { + } +} http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/9b178f63/samoa-api/src/main/java/org/apache/samoa/learners/clusterers/simple/DistributedClusterer.java ---------------------------------------------------------------------- diff --git a/samoa-api/src/main/java/org/apache/samoa/learners/clusterers/simple/DistributedClusterer.java b/samoa-api/src/main/java/org/apache/samoa/learners/clusterers/simple/DistributedClusterer.java new file mode 100644 index 0000000..8f3537a --- /dev/null +++ b/samoa-api/src/main/java/org/apache/samoa/learners/clusterers/simple/DistributedClusterer.java @@ -0,0 +1,121 @@ +package org.apache.samoa.learners.clusterers.simple; + +/* + * #%L + * SAMOA + 
* %% + * Copyright (C) 2014 - 2015 Apache Software Foundation + * %% + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * #L% + */ + +/** + * License + */ + +import com.google.common.collect.ImmutableSet; + +import java.util.Set; + +import org.apache.samoa.core.Processor; +import org.apache.samoa.instances.Instances; +import org.apache.samoa.learners.Learner; +import org.apache.samoa.learners.clusterers.*; +import org.apache.samoa.topology.ProcessingItem; +import org.apache.samoa.topology.Stream; +import org.apache.samoa.topology.TopologyBuilder; + +import com.github.javacliparser.ClassOption; +import com.github.javacliparser.Configurable; +import com.github.javacliparser.IntOption; + +/** + * + * Learner that contain a single learner. 
+ * + */ +public final class DistributedClusterer implements Learner, Configurable { + + private static final long serialVersionUID = 684111382631697031L; + + private Stream resultStream; + + private Instances dataset; + + public ClassOption learnerOption = new ClassOption("learner", 'l', "Clusterer to use.", LocalClustererAdapter.class, + ClustreamClustererAdapter.class.getName()); + + public IntOption paralellismOption = new IntOption("paralellismOption", 'P', + "The paralellism level for concurrent processes", 2, 1, Integer.MAX_VALUE); + + private TopologyBuilder builder; + + // private ClusteringDistributorProcessor distributorP; + private LocalClustererProcessor learnerP; + + // private Stream distributorToLocalStream; + private Stream localToGlobalStream; + + // private int parallelism; + + @Override + public void init(TopologyBuilder builder, Instances dataset, int parallelism) { + this.builder = builder; + this.dataset = dataset; + // this.parallelism = parallelism; + this.setLayout(); + } + + protected void setLayout() { + // Distributor + // distributorP = new ClusteringDistributorProcessor(); + // this.builder.addProcessor(distributorP, parallelism); + // distributorToLocalStream = this.builder.createStream(distributorP); + // distributorP.setOutputStream(distributorToLocalStream); + // distributorToGlobalStream = this.builder.createStream(distributorP); + + // Local Clustering + learnerP = new LocalClustererProcessor(); + LocalClustererAdapter learner = (LocalClustererAdapter) this.learnerOption.getValue(); + learner.setDataset(this.dataset); + learnerP.setLearner(learner); + builder.addProcessor(learnerP, this.paralellismOption.getValue()); + localToGlobalStream = this.builder.createStream(learnerP); + learnerP.setOutputStream(localToGlobalStream); + + // Global Clustering + LocalClustererProcessor globalClusteringCombinerP = new LocalClustererProcessor(); + LocalClustererAdapter globalLearner = (LocalClustererAdapter) this.learnerOption.getValue(); + 
globalLearner.setDataset(this.dataset); + globalClusteringCombinerP.setLearner(learner); + builder.addProcessor(globalClusteringCombinerP, 1); + builder.connectInputAllStream(localToGlobalStream, globalClusteringCombinerP); + + // Output Stream + resultStream = this.builder.createStream(globalClusteringCombinerP); + globalClusteringCombinerP.setOutputStream(resultStream); + } + + @Override + public Processor getInputProcessor() { + // return distributorP; + return learnerP; + } + + @Override + public Set<Stream> getResultStreams() { + Set<Stream> streams = ImmutableSet.of(this.resultStream); + return streams; + } +} http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/9b178f63/samoa-api/src/main/java/org/apache/samoa/moa/AbstractMOAObject.java ---------------------------------------------------------------------- diff --git a/samoa-api/src/main/java/org/apache/samoa/moa/AbstractMOAObject.java b/samoa-api/src/main/java/org/apache/samoa/moa/AbstractMOAObject.java new file mode 100644 index 0000000..2315fde --- /dev/null +++ b/samoa-api/src/main/java/org/apache/samoa/moa/AbstractMOAObject.java @@ -0,0 +1,83 @@ +package org.apache.samoa.moa; + +/* + * #%L + * SAMOA + * %% + * Copyright (C) 2014 - 2015 Apache Software Foundation + * %% + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * #L% + */ + +import org.apache.samoa.moa.core.SerializeUtils; + +//import moa.core.SizeOf; + +/** + * Abstract MOA Object. 
All classes that are serializable, copiable, can measure its size, and can give a description, + * extend this class. + * + * @author Richard Kirkby ([email protected]) + * @version $Revision: 7 $ + */ +public abstract class AbstractMOAObject implements MOAObject { + + @Override + public MOAObject copy() { + return copy(this); + } + + @Override + public int measureByteSize() { + return measureByteSize(this); + } + + /** + * Returns a description of the object. + * + * @return a description of the object + */ + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + getDescription(sb, 0); + return sb.toString(); + } + + /** + * This method produces a copy of an object. + * + * @param obj + * object to copy + * @return a copy of the object + */ + public static MOAObject copy(MOAObject obj) { + try { + return (MOAObject) SerializeUtils.copyObject(obj); + } catch (Exception e) { + throw new RuntimeException("Object copy failed.", e); + } + } + + /** + * Gets the memory size of an object. + * + * @param obj + * object to measure the memory size + * @return the memory size of this object + */ + public static int measureByteSize(MOAObject obj) { + return 0; // (int) SizeOf.fullSizeOf(obj); + } +} http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/9b178f63/samoa-api/src/main/java/org/apache/samoa/moa/MOAObject.java ---------------------------------------------------------------------- diff --git a/samoa-api/src/main/java/org/apache/samoa/moa/MOAObject.java b/samoa-api/src/main/java/org/apache/samoa/moa/MOAObject.java new file mode 100644 index 0000000..60fc28b --- /dev/null +++ b/samoa-api/src/main/java/org/apache/samoa/moa/MOAObject.java @@ -0,0 +1,58 @@ +package org.apache.samoa.moa; + +/* + * #%L + * SAMOA + * %% + * Copyright (C) 2014 - 2015 Apache Software Foundation + * %% + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * #L% + */ + +import java.io.Serializable; + +/** + * Interface implemented by classes in MOA, so that all are serializable, can produce copies of their objects, and can + * measure its memory size. They also give a string description. + * + * @author Richard Kirkby ([email protected]) + * @version $Revision: 7 $ + */ +public interface MOAObject extends Serializable { + + /** + * Gets the memory size of this object. + * + * @return the memory size of this object + */ + public int measureByteSize(); + + /** + * This method produces a copy of this object. + * + * @return a copy of this object + */ + public MOAObject copy(); + + /** + * Returns a string representation of this object. Used in <code>AbstractMOAObject.toString</code> to give a string + * representation of the object. 
+ * + * @param sb + * the stringbuilder to add the description + * @param indent + * the number of characters to indent + */ + public void getDescription(StringBuilder sb, int indent); +} http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/9b178f63/samoa-api/src/main/java/org/apache/samoa/moa/classifiers/AbstractClassifier.java ---------------------------------------------------------------------- diff --git a/samoa-api/src/main/java/org/apache/samoa/moa/classifiers/AbstractClassifier.java b/samoa-api/src/main/java/org/apache/samoa/moa/classifiers/AbstractClassifier.java new file mode 100644 index 0000000..4ca4c0f --- /dev/null +++ b/samoa-api/src/main/java/org/apache/samoa/moa/classifiers/AbstractClassifier.java @@ -0,0 +1,379 @@ +package org.apache.samoa.moa.classifiers; + +/* + * #%L + * SAMOA + * %% + * Copyright (C) 2014 - 2015 Apache Software Foundation + * %% + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * #L% + */ + +import java.util.Arrays; +import java.util.LinkedList; +import java.util.List; +import java.util.Random; + +import org.apache.samoa.instances.Instance; +import org.apache.samoa.instances.InstancesHeader; +import org.apache.samoa.moa.MOAObject; +import org.apache.samoa.moa.core.Example; +import org.apache.samoa.moa.core.Measurement; +import org.apache.samoa.moa.core.ObjectRepository; +import org.apache.samoa.moa.core.StringUtils; +import org.apache.samoa.moa.core.Utils; +import org.apache.samoa.moa.learners.Learner; +import org.apache.samoa.moa.options.AbstractOptionHandler; +import org.apache.samoa.moa.tasks.TaskMonitor; + +import com.github.javacliparser.IntOption; + +public abstract class AbstractClassifier extends AbstractOptionHandler implements Classifier { + + @Override + public String getPurposeString() { + return "MOA Classifier: " + getClass().getCanonicalName(); + } + + /** Header of the instances of the data stream */ + protected InstancesHeader modelContext; + + /** Sum of the weights of the instances trained by this model */ + protected double trainingWeightSeenByModel = 0.0; + + /** Random seed used in randomizable learners */ + protected int randomSeed = 1; + + /** Option for randomizable learners to change the random seed */ + protected IntOption randomSeedOption; + + /** Random Generator used in randomizable learners */ + public Random classifierRandom; + + /** + * Creates an classifier and setups the random seed option if the classifier is randomizable. 
+ */ + public AbstractClassifier() { + if (isRandomizable()) { + this.randomSeedOption = new IntOption("randomSeed", 'r', + "Seed for random behaviour of the classifier.", 1); + } + } + + @Override + public void prepareForUseImpl(TaskMonitor monitor, + ObjectRepository repository) { + if (this.randomSeedOption != null) { + this.randomSeed = this.randomSeedOption.getValue(); + } + if (!trainingHasStarted()) { + resetLearning(); + } + } + + @Override + public double[] getVotesForInstance(Example<Instance> example) { + return getVotesForInstance(example.getData()); + } + + @Override + public abstract double[] getVotesForInstance(Instance inst); + + @Override + public void setModelContext(InstancesHeader ih) { + if ((ih != null) && (ih.classIndex() < 0)) { + throw new IllegalArgumentException( + "Context for a classifier must include a class to learn"); + } + if (trainingHasStarted() + && (this.modelContext != null) + && ((ih == null) || !contextIsCompatible(this.modelContext, ih))) { + throw new IllegalArgumentException( + "New context is not compatible with existing model"); + } + this.modelContext = ih; + } + + @Override + public InstancesHeader getModelContext() { + return this.modelContext; + } + + @Override + public void setRandomSeed(int s) { + this.randomSeed = s; + if (this.randomSeedOption != null) { + // keep option consistent + this.randomSeedOption.setValue(s); + } + } + + @Override + public boolean trainingHasStarted() { + return this.trainingWeightSeenByModel > 0.0; + } + + @Override + public double trainingWeightSeenByModel() { + return this.trainingWeightSeenByModel; + } + + @Override + public void resetLearning() { + this.trainingWeightSeenByModel = 0.0; + if (isRandomizable()) { + this.classifierRandom = new Random(this.randomSeed); + } + resetLearningImpl(); + } + + @Override + public void trainOnInstance(Instance inst) { + if (inst.weight() > 0.0) { + this.trainingWeightSeenByModel += inst.weight(); + trainOnInstanceImpl(inst); + } + } + + 
@Override + public Measurement[] getModelMeasurements() { + List<Measurement> measurementList = new LinkedList<>(); + measurementList.add(new Measurement("model training instances", + trainingWeightSeenByModel())); + measurementList.add(new Measurement("model serialized size (bytes)", + measureByteSize())); + Measurement[] modelMeasurements = getModelMeasurementsImpl(); + if (modelMeasurements != null) { + measurementList.addAll(Arrays.asList(modelMeasurements)); + } + // add average of sub-model measurements + Learner[] subModels = getSublearners(); + if ((subModels != null) && (subModels.length > 0)) { + List<Measurement[]> subMeasurements = new LinkedList<>(); + for (Learner subModel : subModels) { + if (subModel != null) { + subMeasurements.add(subModel.getModelMeasurements()); + } + } + Measurement[] avgMeasurements = Measurement.averageMeasurements(subMeasurements + .toArray(new Measurement[subMeasurements.size()][])); + measurementList.addAll(Arrays.asList(avgMeasurements)); + } + return measurementList.toArray(new Measurement[measurementList.size()]); + } + + @Override + public void getDescription(StringBuilder out, int indent) { + StringUtils.appendIndented(out, indent, "Model type: "); + out.append(this.getClass().getName()); + StringUtils.appendNewline(out); + Measurement.getMeasurementsDescription(getModelMeasurements(), out, + indent); + StringUtils.appendNewlineIndented(out, indent, "Model description:"); + StringUtils.appendNewline(out); + if (trainingHasStarted()) { + getModelDescription(out, indent); + } else { + StringUtils.appendIndented(out, indent, + "Model has not been trained."); + } + } + + @Override + public Learner[] getSublearners() { + return null; + } + + @Override + public Classifier[] getSubClassifiers() { + return null; + } + + @Override + public Classifier copy() { + return (Classifier) super.copy(); + } + + @Override + public MOAObject getModel() { + return this; + } + + @Override + public void trainOnInstance(Example<Instance> 
example) { + trainOnInstance(example.getData()); + } + + @Override + public boolean correctlyClassifies(Instance inst) { + return Utils.maxIndex(getVotesForInstance(inst)) == (int) inst.classValue(); + } + + /** + * Gets the name of the attribute of the class from the header. + * + * @return the string with name of the attribute of the class + */ + public String getClassNameString() { + return InstancesHeader.getClassNameString(this.modelContext); + } + + /** + * Gets the name of a label of the class from the header. + * + * @param classLabelIndex + * the label index + * @return the name of the label of the class + */ + public String getClassLabelString(int classLabelIndex) { + return InstancesHeader.getClassLabelString(this.modelContext, + classLabelIndex); + } + + /** + * Gets the name of an attribute from the header. + * + * @param attIndex + * the attribute index + * @return the name of the attribute + */ + public String getAttributeNameString(int attIndex) { + return InstancesHeader.getAttributeNameString(this.modelContext, attIndex); + } + + /** + * Gets the name of a value of an attribute from the header. 
+ * + * @param attIndex + * the attribute index + * @param valIndex + * the value of the attribute + * @return the name of the value of the attribute + */ + public String getNominalValueString(int attIndex, int valIndex) { + return InstancesHeader.getNominalValueString(this.modelContext, attIndex, valIndex); + } + + /** + * Returns if two contexts or headers of instances are compatible.<br> + * <br> + * + * Two contexts are compatible if they follow the following rules:<br> + * Rule 1: num classes can increase but never decrease<br> + * Rule 2: num attributes can increase but never decrease<br> + * Rule 3: num nominal attribute values can increase but never decrease<br> + * Rule 4: attribute types must stay in the same order (although class can move; is always skipped over)<br> + * <br> + * + * Attribute names are free to change, but should always still represent the original attributes. + * + * @param originalContext + * the first context to compare + * @param newContext + * the second context to compare + * @return true if the two contexts are compatible. 
+ */ + public static boolean contextIsCompatible(InstancesHeader originalContext, + InstancesHeader newContext) { + + if (newContext.numClasses() < originalContext.numClasses()) { + return false; // rule 1 + } + if (newContext.numAttributes() < originalContext.numAttributes()) { + return false; // rule 2 + } + int oPos = 0; + int nPos = 0; + while (oPos < originalContext.numAttributes()) { + if (oPos == originalContext.classIndex()) { + oPos++; + if (!(oPos < originalContext.numAttributes())) { + break; + } + } + if (nPos == newContext.classIndex()) { + nPos++; + } + if (originalContext.attribute(oPos).isNominal()) { + if (!newContext.attribute(nPos).isNominal()) { + return false; // rule 4 + } + if (newContext.attribute(nPos).numValues() < originalContext.attribute(oPos).numValues()) { + return false; // rule 3 + } + } else { + assert (originalContext.attribute(oPos).isNumeric()); + if (!newContext.attribute(nPos).isNumeric()) { + return false; // rule 4 + } + } + oPos++; + nPos++; + } + return true; // all checks clear + } + + /** + * Resets this classifier. It must be similar to starting a new classifier from scratch. <br> + * <br> + * + * The reason for ...Impl methods: ease programmer burden by not requiring them to remember calls to super in + * overridden methods. Note that this will produce compiler errors if not overridden. + */ + public abstract void resetLearningImpl(); + + /** + * Trains this classifier incrementally using the given instance.<br> + * <br> + * + * The reason for ...Impl methods: ease programmer burden by not requiring them to remember calls to super in + * overridden methods. Note that this will produce compiler errors if not overridden. 
+ * + * @param inst + * the instance to be used for training + */ + public abstract void trainOnInstanceImpl(Instance inst); + + /** + * Gets the current measurements of this classifier.<br> + * <br> + * + * The reason for ...Impl methods: ease programmer burden by not requiring them to remember calls to super in + * overridden methods. Note that this will produce compiler errors if not overridden. + * + * @return an array of measurements to be used in evaluation tasks + */ + protected abstract Measurement[] getModelMeasurementsImpl(); + + /** + * Returns a string representation of the model. + * + * @param out + * the stringbuilder to add the description + * @param indent + * the number of characters to indent + */ + public abstract void getModelDescription(StringBuilder out, int indent); + + /** + * Gets the index of the attribute in the instance, given the index of the attribute in the learner. + * + * @param index + * the index of the attribute in the learner + * @return the index in the instance + */ + protected static int modelAttIndexToInstanceAttIndex(int index) { + return index; // inst.classIndex() > index ? index : index + 1; + } +} http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/9b178f63/samoa-api/src/main/java/org/apache/samoa/moa/classifiers/Classifier.java ---------------------------------------------------------------------- diff --git a/samoa-api/src/main/java/org/apache/samoa/moa/classifiers/Classifier.java b/samoa-api/src/main/java/org/apache/samoa/moa/classifiers/Classifier.java new file mode 100644 index 0000000..cd40888 --- /dev/null +++ b/samoa-api/src/main/java/org/apache/samoa/moa/classifiers/Classifier.java @@ -0,0 +1,77 @@ +package org.apache.samoa.moa.classifiers; + +/* + * #%L + * SAMOA + * %% + * Copyright (C) 2014 - 2015 Apache Software Foundation + * %% + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * #L% + */ + +import org.apache.samoa.instances.Instance; +import org.apache.samoa.moa.core.Example; +import org.apache.samoa.moa.learners.Learner; + +/** + * Classifier interface for incremental classification models. + * + * @author Richard Kirkby ([email protected]) + * @version $Revision: 7 $ + */ +public interface Classifier extends Learner<Example<Instance>> { + + /** + * Gets the classifiers of this ensemble. Returns null if this learner is a single learner. + * + * @return an array of the learners of the ensemble + */ + public Classifier[] getSubClassifiers(); + + /** + * Produces a copy of this learner. + * + * @return the copy of this learner + */ + public Classifier copy(); + + /** + * Gets whether this classifier correctly classifies an instance. Uses getVotesForInstance to obtain the prediction + * and the instance to obtain its true class. + * + * + * @param inst + * the instance to be classified + * @return true if the instance is correctly classified + */ + public boolean correctlyClassifies(Instance inst); + + /** + * Trains this learner incrementally using the given example. + * + * @param inst + * the instance to be used for training + */ + public void trainOnInstance(Instance inst); + + /** + * Predicts the class memberships for a given instance. If an instance is unclassified, the returned array elements + * must be all zero. 
+ * + * @param inst + * the instance to be classified + * @return an array containing the estimated membership probabilities of the test instance in each class + */ + public double[] getVotesForInstance(Instance inst); +} http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/9b178f63/samoa-api/src/main/java/org/apache/samoa/moa/classifiers/Regressor.java ---------------------------------------------------------------------- diff --git a/samoa-api/src/main/java/org/apache/samoa/moa/classifiers/Regressor.java b/samoa-api/src/main/java/org/apache/samoa/moa/classifiers/Regressor.java new file mode 100644 index 0000000..5baf627 --- /dev/null +++ b/samoa-api/src/main/java/org/apache/samoa/moa/classifiers/Regressor.java @@ -0,0 +1,31 @@ +package org.apache.samoa.moa.classifiers; + +/* + * #%L + * SAMOA + * %% + * Copyright (C) 2014 - 2015 Apache Software Foundation + * %% + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * #L% + */ + +/** + * Regressor interface for incremental regression models. It is used only in the GUI Regression Tab. 
+ * + * @author Richard Kirkby ([email protected]) + * @version $Revision: 7 $ + */ +public interface Regressor { + +} http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/9b178f63/samoa-api/src/main/java/org/apache/samoa/moa/classifiers/core/AttributeSplitSuggestion.java ---------------------------------------------------------------------- diff --git a/samoa-api/src/main/java/org/apache/samoa/moa/classifiers/core/AttributeSplitSuggestion.java b/samoa-api/src/main/java/org/apache/samoa/moa/classifiers/core/AttributeSplitSuggestion.java new file mode 100644 index 0000000..2eeade5 --- /dev/null +++ b/samoa-api/src/main/java/org/apache/samoa/moa/classifiers/core/AttributeSplitSuggestion.java @@ -0,0 +1,69 @@ +package org.apache.samoa.moa.classifiers.core; + +/* + * #%L + * SAMOA + * %% + * Copyright (C) 2014 - 2015 Apache Software Foundation + * %% + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * #L% + */ + +import org.apache.samoa.moa.AbstractMOAObject; +import org.apache.samoa.moa.classifiers.core.conditionaltests.InstanceConditionalTest; + +/** + * Class for computing attribute split suggestions given a split test. 
+ * + * @author Richard Kirkby ([email protected]) + * @version $Revision: 7 $ + */ +public class AttributeSplitSuggestion extends AbstractMOAObject implements Comparable<AttributeSplitSuggestion> { + + private static final long serialVersionUID = 1L; + + public InstanceConditionalTest splitTest; + + public double[][] resultingClassDistributions; + + public double merit; + + public AttributeSplitSuggestion() { + } + + public AttributeSplitSuggestion(InstanceConditionalTest splitTest, + double[][] resultingClassDistributions, double merit) { + this.splitTest = splitTest; + this.resultingClassDistributions = resultingClassDistributions.clone(); + this.merit = merit; + } + + public int numSplits() { + return this.resultingClassDistributions.length; + } + + public double[] resultingClassDistributionFromSplit(int splitIndex) { + return this.resultingClassDistributions[splitIndex].clone(); + } + + @Override + public int compareTo(AttributeSplitSuggestion comp) { + return Double.compare(this.merit, comp.merit); + } + + @Override + public void getDescription(StringBuilder sb, int indent) { + // do nothing + } +} http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/9b178f63/samoa-api/src/main/java/org/apache/samoa/moa/classifiers/core/attributeclassobservers/AttributeClassObserver.java ---------------------------------------------------------------------- diff --git a/samoa-api/src/main/java/org/apache/samoa/moa/classifiers/core/attributeclassobservers/AttributeClassObserver.java b/samoa-api/src/main/java/org/apache/samoa/moa/classifiers/core/attributeclassobservers/AttributeClassObserver.java new file mode 100644 index 0000000..5efdf5d --- /dev/null +++ b/samoa-api/src/main/java/org/apache/samoa/moa/classifiers/core/attributeclassobservers/AttributeClassObserver.java @@ -0,0 +1,79 @@ +package org.apache.samoa.moa.classifiers.core.attributeclassobservers; + +/* + * #%L + * SAMOA + * %% + * Copyright (C) 2014 - 2015 Apache Software Foundation + * %% + * Licensed under 
the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * #L% + */ + +import org.apache.samoa.moa.classifiers.core.AttributeSplitSuggestion; +import org.apache.samoa.moa.classifiers.core.splitcriteria.SplitCriterion; +import org.apache.samoa.moa.options.OptionHandler; + +/** + * Interface for observing the class data distribution for an attribute. This observer monitors the class distribution + * of a given attribute. Used in naive Bayes and decision trees to monitor data statistics on leaves. + * + * @author Richard Kirkby ([email protected]) + * @version $Revision: 7 $ + */ +public interface AttributeClassObserver extends OptionHandler { + + /** + * Updates statistics of this observer given an attribute value, a class and the weight of the instance observed + * + * @param attVal + * the value of the attribute + * @param classVal + * the class + * @param weight + * the weight of the instance + */ + public void observeAttributeClass(double attVal, int classVal, double weight); + + /** + * Gets the probability for an attribute value given a class + * + * @param attVal + * the attribute value + * @param classVal + * the class + * @return probability for an attribute value given a class + */ + public double probabilityOfAttributeValueGivenClass(double attVal, + int classVal); + + /** + * Gets the best split suggestion given a criterion and a class distribution + * + * @param criterion + * the split criterion to use + * @param preSplitDist + * the class distribution before the split + * 
@param attIndex + * the attribute index + * @param binaryOnly + * true to use binary splits + * @return suggestion of best attribute split + */ + public AttributeSplitSuggestion getBestEvaluatedSplitSuggestion( + SplitCriterion criterion, double[] preSplitDist, int attIndex, + boolean binaryOnly); + + public void observeAttributeTarget(double attVal, double target); + +} http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/9b178f63/samoa-api/src/main/java/org/apache/samoa/moa/classifiers/core/attributeclassobservers/BinaryTreeNumericAttributeClassObserver.java ---------------------------------------------------------------------- diff --git a/samoa-api/src/main/java/org/apache/samoa/moa/classifiers/core/attributeclassobservers/BinaryTreeNumericAttributeClassObserver.java b/samoa-api/src/main/java/org/apache/samoa/moa/classifiers/core/attributeclassobservers/BinaryTreeNumericAttributeClassObserver.java new file mode 100644 index 0000000..a638318 --- /dev/null +++ b/samoa-api/src/main/java/org/apache/samoa/moa/classifiers/core/attributeclassobservers/BinaryTreeNumericAttributeClassObserver.java @@ -0,0 +1,184 @@ +package org.apache.samoa.moa.classifiers.core.attributeclassobservers; + +/* + * #%L + * SAMOA + * %% + * Copyright (C) 2014 - 2015 Apache Software Foundation + * %% + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * #L% + */ + +import java.io.Serializable; + +import org.apache.samoa.moa.classifiers.core.AttributeSplitSuggestion; +import org.apache.samoa.moa.classifiers.core.conditionaltests.NumericAttributeBinaryTest; +import org.apache.samoa.moa.classifiers.core.splitcriteria.SplitCriterion; +import org.apache.samoa.moa.core.DoubleVector; +import org.apache.samoa.moa.core.ObjectRepository; +import org.apache.samoa.moa.options.AbstractOptionHandler; +import org.apache.samoa.moa.tasks.TaskMonitor; + +/** + * Class for observing the class data distribution for a numeric attribute using a binary tree. This observer monitors + * the class distribution of a given attribute. Used in naive Bayes and decision trees to monitor data statistics on + * leaves. + * + * @author Richard Kirkby ([email protected]) + * @version $Revision: 7 $ + */ +public class BinaryTreeNumericAttributeClassObserver extends AbstractOptionHandler + implements NumericAttributeClassObserver { + + private static final long serialVersionUID = 1L; + + public class Node implements Serializable { + + private static final long serialVersionUID = 1L; + + public double cut_point; + + public DoubleVector classCountsLeft = new DoubleVector(); + + public DoubleVector classCountsRight = new DoubleVector(); + + public Node left; + + public Node right; + + public Node(double val, int label, double weight) { + this.cut_point = val; + this.classCountsLeft.addToValue(label, weight); + } + + public void insertValue(double val, int label, double weight) { + if (val == this.cut_point) { + this.classCountsLeft.addToValue(label, weight); + } else if (val <= this.cut_point) { + this.classCountsLeft.addToValue(label, weight); + if (this.left == null) { + this.left = new Node(val, label, weight); + } else { + this.left.insertValue(val, label, weight); + } + } else { // val > cut_point + this.classCountsRight.addToValue(label, weight); + if (this.right == null) { + this.right = new Node(val, label, weight); + } else { + 
this.right.insertValue(val, label, weight); + } + } + } + } + + public Node root = null; + + @Override + public void observeAttributeClass(double attVal, int classVal, double weight) { + if (Double.isNaN(attVal)) { // Instance.isMissingValue(attVal) + } else { + if (this.root == null) { + this.root = new Node(attVal, classVal, weight); + } else { + this.root.insertValue(attVal, classVal, weight); + } + } + } + + @Override + public double probabilityOfAttributeValueGivenClass(double attVal, + int classVal) { + // TODO: NaiveBayes broken until implemented + return 0.0; + } + + @Override + public AttributeSplitSuggestion getBestEvaluatedSplitSuggestion( + SplitCriterion criterion, double[] preSplitDist, int attIndex, + boolean binaryOnly) { + return searchForBestSplitOption(this.root, null, null, null, null, false, + criterion, preSplitDist, attIndex); + } + + protected AttributeSplitSuggestion searchForBestSplitOption( + Node currentNode, AttributeSplitSuggestion currentBestOption, + double[] actualParentLeft, + double[] parentLeft, double[] parentRight, boolean leftChild, + SplitCriterion criterion, double[] preSplitDist, int attIndex) { + if (currentNode == null) { + return currentBestOption; + } + DoubleVector leftDist = new DoubleVector(); + DoubleVector rightDist = new DoubleVector(); + if (parentLeft == null) { + leftDist.addValues(currentNode.classCountsLeft); + rightDist.addValues(currentNode.classCountsRight); + } else { + leftDist.addValues(parentLeft); + rightDist.addValues(parentRight); + if (leftChild) { + // get the exact statistics of the parent value + DoubleVector exactParentDist = new DoubleVector(); + exactParentDist.addValues(actualParentLeft); + exactParentDist.subtractValues(currentNode.classCountsLeft); + exactParentDist.subtractValues(currentNode.classCountsRight); + + // move the subtrees + leftDist.subtractValues(currentNode.classCountsRight); + rightDist.addValues(currentNode.classCountsRight); + + // move the exact value from the parent + 
rightDist.addValues(exactParentDist); + leftDist.subtractValues(exactParentDist); + + } else { + leftDist.addValues(currentNode.classCountsLeft); + rightDist.subtractValues(currentNode.classCountsLeft); + } + } + double[][] postSplitDists = new double[][] { leftDist.getArrayRef(), + rightDist.getArrayRef() }; + double merit = criterion.getMeritOfSplit(preSplitDist, postSplitDists); + if ((currentBestOption == null) || (merit > currentBestOption.merit)) { + currentBestOption = new AttributeSplitSuggestion( + new NumericAttributeBinaryTest(attIndex, + currentNode.cut_point, true), postSplitDists, merit); + + } + currentBestOption = searchForBestSplitOption(currentNode.left, + currentBestOption, currentNode.classCountsLeft.getArrayRef(), postSplitDists[0], postSplitDists[1], true, + criterion, preSplitDist, attIndex); + currentBestOption = searchForBestSplitOption(currentNode.right, + currentBestOption, currentNode.classCountsLeft.getArrayRef(), postSplitDists[0], postSplitDists[1], false, + criterion, preSplitDist, attIndex); + return currentBestOption; + } + + @Override + public void getDescription(StringBuilder sb, int indent) { + // TODO Auto-generated method stub + } + + @Override + protected void prepareForUseImpl(TaskMonitor monitor, ObjectRepository repository) { + // TODO Auto-generated method stub + } + + @Override + public void observeAttributeTarget(double attVal, double target) { + throw new UnsupportedOperationException("Not supported yet."); + } + +} http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/9b178f63/samoa-api/src/main/java/org/apache/samoa/moa/classifiers/core/attributeclassobservers/BinaryTreeNumericAttributeClassObserverRegression.java ---------------------------------------------------------------------- diff --git a/samoa-api/src/main/java/org/apache/samoa/moa/classifiers/core/attributeclassobservers/BinaryTreeNumericAttributeClassObserverRegression.java 
b/samoa-api/src/main/java/org/apache/samoa/moa/classifiers/core/attributeclassobservers/BinaryTreeNumericAttributeClassObserverRegression.java new file mode 100644 index 0000000..f37ac80 --- /dev/null +++ b/samoa-api/src/main/java/org/apache/samoa/moa/classifiers/core/attributeclassobservers/BinaryTreeNumericAttributeClassObserverRegression.java @@ -0,0 +1,149 @@ +package org.apache.samoa.moa.classifiers.core.attributeclassobservers; + +/* + * #%L + * SAMOA + * %% + * Copyright (C) 2014 - 2015 Apache Software Foundation + * %% + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * #L% + */ + +import java.io.Serializable; + +import org.apache.samoa.moa.classifiers.core.AttributeSplitSuggestion; +import org.apache.samoa.moa.classifiers.core.splitcriteria.SplitCriterion; +import org.apache.samoa.moa.core.ObjectRepository; +import org.apache.samoa.moa.options.AbstractOptionHandler; +import org.apache.samoa.moa.tasks.TaskMonitor; + +/** + * Class for observing the class data distribution for a numeric attribute using a binary tree. This observer monitors + * the class distribution of a given attribute. + * + * <p> + * Learning Adaptive Model Rules from High-Speed Data Streams, ECML 2013, E. Almeida, C. Ferreira, P. Kosina and J. + * Gama; + * </p> + * + * @author E. Almeida, J. 
Gama + * @version $Revision: 2$ + */ +public class BinaryTreeNumericAttributeClassObserverRegression extends AbstractOptionHandler + implements NumericAttributeClassObserver { + + public static final long serialVersionUID = 1L; + + public class Node implements Serializable { + + private static final long serialVersionUID = 1L; + + public double cut_point; + + public double[] lessThan; // This array maintains statistics for the instance reaching the node with attribute values less than or iqual to the cutpoint. + + public double[] greaterThan; // This array maintains statistics for the instance reaching the node with attribute values greater than to the cutpoint. + + public Node left; + + public Node right; + + public Node(double val, double target) { + this.cut_point = val; + this.lessThan = new double[3]; + this.greaterThan = new double[3]; + this.lessThan[0] = target; // The sum of their target attribute values. + this.lessThan[1] = target * target; // The sum of the squared target attribute values. + this.lessThan[2] = 1.0; // A counter of the number of instances that have reached the node. 
+ this.greaterThan[0] = 0.0; + this.greaterThan[1] = 0.0; + this.greaterThan[2] = 0.0; + } + + public void insertValue(double val, double target) { + if (val == this.cut_point) { + this.lessThan[0] = this.lessThan[0] + target; + this.lessThan[1] = this.lessThan[1] + (target * target); + this.lessThan[2] = this.lessThan[2] + 1; + } else if (val <= this.cut_point) { + this.lessThan[0] = this.lessThan[0] + target; + this.lessThan[1] = this.lessThan[1] + (target * target); + this.lessThan[2] = this.lessThan[2] + 1; + if (this.left == null) { + this.left = new Node(val, target); + } else { + this.left.insertValue(val, target); + } + } else { + this.greaterThan[0] = this.greaterThan[0] + target; + this.greaterThan[1] = this.greaterThan[1] + (target * target); + this.greaterThan[2] = this.greaterThan[2] + 1; + if (this.right == null) { + + this.right = new Node(val, target); + } else { + this.right.insertValue(val, target); + } + } + } + } + + public Node root1 = null; + + public void observeAttributeTarget(double attVal, double target) { + if (!Double.isNaN(attVal)) { + if (this.root1 == null) { + this.root1 = new Node(attVal, target); + } else { + this.root1.insertValue(attVal, target); + } + } + } + + @Override + public void observeAttributeClass(double attVal, int classVal, double weight) { + + } + + @Override + public double probabilityOfAttributeValueGivenClass(double attVal, + int classVal) { + return 0.0; + } + + @Override + public AttributeSplitSuggestion getBestEvaluatedSplitSuggestion( + SplitCriterion criterion, double[] preSplitDist, int attIndex, + boolean binaryOnly) { + return searchForBestSplitOption(this.root1, null, null, null, null, false, + criterion, preSplitDist, attIndex); + } + + protected AttributeSplitSuggestion searchForBestSplitOption( + Node currentNode, AttributeSplitSuggestion currentBestOption, + double[] actualParentLeft, + double[] parentLeft, double[] parentRight, boolean leftChild, + SplitCriterion criterion, double[] preSplitDist, int 
attIndex) { + + return currentBestOption; + } + + @Override + public void getDescription(StringBuilder sb, int indent) { + } + + @Override + protected void prepareForUseImpl(TaskMonitor monitor, ObjectRepository repository) { + } +} http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/9b178f63/samoa-api/src/main/java/org/apache/samoa/moa/classifiers/core/attributeclassobservers/DiscreteAttributeClassObserver.java ---------------------------------------------------------------------- diff --git a/samoa-api/src/main/java/org/apache/samoa/moa/classifiers/core/attributeclassobservers/DiscreteAttributeClassObserver.java b/samoa-api/src/main/java/org/apache/samoa/moa/classifiers/core/attributeclassobservers/DiscreteAttributeClassObserver.java new file mode 100644 index 0000000..bbf6194 --- /dev/null +++ b/samoa-api/src/main/java/org/apache/samoa/moa/classifiers/core/attributeclassobservers/DiscreteAttributeClassObserver.java @@ -0,0 +1,32 @@ +package org.apache.samoa.moa.classifiers.core.attributeclassobservers; + +/* + * #%L + * SAMOA + * %% + * Copyright (C) 2014 - 2015 Apache Software Foundation + * %% + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * #L% + */ + +/** + * Interface for observing the class data distribution for a discrete (nominal) attribute. This observer monitors the + * class distribution of a given attribute. Used in naive Bayes and decision trees to monitor data statistics on leaves. 
+ * + * @author Richard Kirkby ([email protected]) + * @version $Revision: 7 $ + */ +public interface DiscreteAttributeClassObserver extends AttributeClassObserver { + +} http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/9b178f63/samoa-api/src/main/java/org/apache/samoa/moa/classifiers/core/attributeclassobservers/FIMTDDNumericAttributeClassObserver.java ---------------------------------------------------------------------- diff --git a/samoa-api/src/main/java/org/apache/samoa/moa/classifiers/core/attributeclassobservers/FIMTDDNumericAttributeClassObserver.java b/samoa-api/src/main/java/org/apache/samoa/moa/classifiers/core/attributeclassobservers/FIMTDDNumericAttributeClassObserver.java new file mode 100644 index 0000000..d61d8c8 --- /dev/null +++ b/samoa-api/src/main/java/org/apache/samoa/moa/classifiers/core/attributeclassobservers/FIMTDDNumericAttributeClassObserver.java @@ -0,0 +1,250 @@ +/* Project Knowledge Discovery from Data Streams, FCT LIAAD-INESC TEC, + * + * Contact: [email protected] + */ + +package org.apache.samoa.moa.classifiers.core.attributeclassobservers; + +/* + * #%L + * SAMOA + * %% + * Copyright (C) 2014 - 2015 Apache Software Foundation + * %% + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * #L% + */

import java.io.Serializable;

import org.apache.samoa.moa.classifiers.core.AttributeSplitSuggestion;
import org.apache.samoa.moa.classifiers.core.conditionaltests.NumericAttributeBinaryTest;
import org.apache.samoa.moa.classifiers.core.splitcriteria.SplitCriterion;
import org.apache.samoa.moa.core.DoubleVector;
import org.apache.samoa.moa.core.ObjectRepository;
import org.apache.samoa.moa.tasks.TaskMonitor;

/**
 * Numeric attribute observer that keeps target statistics in an extended binary search tree (E-BST) and searches it
 * with the FindBestSplit algorithm of E. Ikonomovska et al.
 *
 * <p>Each {@link Node} stores a candidate cut point and two statistics vectors. Entry 0 of a vector is the number of
 * observations, entry 1 the sum of their target values and entry 2 the sum of their squared target values.
 */
public class FIMTDDNumericAttributeClassObserver extends BinaryTreeNumericAttributeClassObserver implements
    NumericAttributeClassObserver {

  private static final long serialVersionUID = 1L;

  /**
   * One node of the E-BST. {@code leftStatistics} covers the observations routed through this node whose attribute
   * value is less than or equal to {@code cut_point}. {@code rightStatistics} covers all other observations routed
   * through it.
   */
  protected class Node implements Serializable {

    private static final long serialVersionUID = 1L;

    // The split point to use
    public double cut_point;

    // E-BST statistics: [0] count, [1] sum of targets, [2] sum of squared targets
    public DoubleVector leftStatistics = new DoubleVector();
    public DoubleVector rightStatistics = new DoubleVector();

    // Child nodes
    public Node left;
    public Node right;

    /**
     * Creates a node whose cut point is {@code val}. The first observation is recorded on its left (<=) side.
     * {@code weight} is currently ignored: every observation counts as 1.
     */
    public Node(double val, double label, double weight) {
      this.cut_point = val;
      addToStatistics(this.leftStatistics, label);
    }

    /**
     * Inserts a new value into the tree, updating the count, sum and sum of squares of every node on its path.
     * {@code weight} is currently ignored.
     */
    public void insertValue(double val, double label, double weight) {
      if (val == this.cut_point) {
        // Equal to this cut point: belongs to the left (<=) side and needs no new node.
        addToStatistics(this.leftStatistics, label);
      } else if (val < this.cut_point) {
        // Smaller: update the left side and send the value down to the left child, creating it if needed.
        addToStatistics(this.leftStatistics, label);
        if (this.left == null) {
          this.left = new Node(val, label, weight);
        } else {
          this.left.insertValue(val, label, weight);
        }
      } else {
        // Greater (or not comparable): update the right side and send the value down to the right child.
        addToStatistics(this.rightStatistics, label);
        if (this.right == null) {
          this.right = new Node(val, label, weight);
        } else {
          this.right.insertValue(val, label, weight);
        }
      }
    }
  }

  // Root node of the E-BST structure for this attribute
  public Node root = null;

  // Running totals used by the FindBestSplit traversal; reset by getBestEvaluatedSplitSuggestion
  double sumTotalLeft;
  double sumTotalRight;
  double sumSqTotalLeft;
  double sumSqTotalRight;
  double countRightTotal;
  double countLeftTotal;

  /** Adds one observation with target {@code label} to a {count, sum, sum of squares} statistics vector. */
  private static void addToStatistics(DoubleVector stats, double label) {
    stats.addToValue(0, 1);
    stats.addToValue(1, label);
    stats.addToValue(2, label * label);
  }

  /**
   * Records one (attribute value, target value) pair in the E-BST. Missing (NaN) attribute values are ignored, and
   * {@code weight} is currently ignored.
   */
  public void observeAttributeClass(double attVal, double classVal, double weight) {
    if (Double.isNaN(attVal)) {
      return;
    }
    if (this.root == null) {
      this.root = new Node(attVal, classVal, weight);
    } else {
      this.root.insertValue(attVal, classVal, weight);
    }
  }

  @Override
  public double probabilityOfAttributeValueGivenClass(double attVal,
      int classVal) {
    // TODO: NaiveBayes broken until implemented
    return 0.0;
  }

  /**
   * Evaluates every cut point stored in the E-BST and returns the best binary split.
   *
   * @param criterion  scores a candidate split
   * @param preSplitDist statistics before the split: {count, sum of targets, sum of squared targets}
   * @param attIndex   index of the observed attribute, used to build the split test
   * @param binaryOnly unused: this observer only produces binary splits
   * @return the best split suggestion, or {@code null} if the tree is empty
   */
  @Override
  public AttributeSplitSuggestion getBestEvaluatedSplitSuggestion(SplitCriterion criterion, double[] preSplitDist,
      int attIndex, boolean binaryOnly) {

    // Start with every observation on the right side of the split
    sumTotalLeft = 0;
    sumTotalRight = preSplitDist[1];
    sumSqTotalLeft = 0;
    sumSqTotalRight = preSplitDist[2];
    countLeftTotal = 0;
    countRightTotal = preSplitDist[0];
    return searchForBestSplitOption(this.root, null, criterion, attIndex);
  }

  /**
   * Implementation of the FindBestSplit algorithm from E. Ikonomovska et al. It traverses the E-BST in order and,
   * at each node, moves that node's left statistics from the right partition to the left before scoring its cut
   * point. The running totals are restored before returning.
   *
   * @return the better of {@code currentBestOption} and the best split found in this subtree (may be {@code null})
   */
  protected AttributeSplitSuggestion searchForBestSplitOption(Node currentNode,
      AttributeSplitSuggestion currentBestOption, SplitCriterion criterion, int attIndex) {
    // Return if the current node is null or we have finished looking through all the possible splits
    if (currentNode == null || countRightTotal == 0.0) {
      return currentBestOption;
    }

    // Smaller cut points are evaluated first
    if (currentNode.left != null) {
      currentBestOption = searchForBestSplitOption(currentNode.left, currentBestOption, criterion, attIndex);
    }

    // Shift this node's (<= cut_point) statistics to the left partition
    sumTotalLeft += currentNode.leftStatistics.getValue(1);
    sumTotalRight -= currentNode.leftStatistics.getValue(1);
    sumSqTotalLeft += currentNode.leftStatistics.getValue(2);
    sumSqTotalRight -= currentNode.leftStatistics.getValue(2);
    countLeftTotal += currentNode.leftStatistics.getValue(0);
    countRightTotal -= currentNode.leftStatistics.getValue(0);

    double[][] postSplitDists = new double[][] { { countLeftTotal, sumTotalLeft, sumSqTotalLeft },
        { countRightTotal, sumTotalRight, sumSqTotalRight } };
    double[] preSplitDist = new double[] { (countLeftTotal + countRightTotal), (sumTotalLeft + sumTotalRight),
        (sumSqTotalLeft + sumSqTotalRight) };
    double merit = criterion.getMeritOfSplit(preSplitDist, postSplitDists);

    if ((currentBestOption == null) || (merit > currentBestOption.merit)) {
      currentBestOption = new AttributeSplitSuggestion(
          new NumericAttributeBinaryTest(attIndex,
              currentNode.cut_point, true), postSplitDists, merit);
    }

    if (currentNode.right != null) {
      currentBestOption = searchForBestSplitOption(currentNode.right, currentBestOption, criterion, attIndex);
    }

    // Undo the shift so the caller sees the totals it had before visiting this node
    sumTotalLeft -= currentNode.leftStatistics.getValue(1);
    sumTotalRight += currentNode.leftStatistics.getValue(1);
    sumSqTotalLeft -= currentNode.leftStatistics.getValue(2);
    sumSqTotalRight += currentNode.leftStatistics.getValue(2);
    countLeftTotal -= currentNode.leftStatistics.getValue(0);
    countRightTotal += currentNode.leftStatistics.getValue(0);

    return currentBestOption;
  }

  /**
   * Removes from the E-BST every node which, together with all of its children, represents a 'bad' split point.
   * A node's split is bad when {@code merit / lastCheckSDR < lastCheckRatio - 2 * lastCheckE}, where the merit is
   * computed by {@code criterion} from the node's own left and right statistics. If the root itself is bad, the
   * whole tree is discarded.
   */
  public void removeBadSplits(SplitCriterion criterion, double lastCheckRatio, double lastCheckSDR, double lastCheckE) {
    if (removeBadSplitNodes(criterion, this.root, lastCheckRatio, lastCheckSDR, lastCheckE)) {
      this.root = null;
    }
  }

  /**
   * Recursively checks all of a node's children before deciding whether the node itself is 'bad'. Bad children are
   * detached here. Whether {@code currentNode} is bad is returned so that its caller can detach it. A node can only
   * be bad once all of its children are bad, and an absent ({@code null}) node counts as bad.
   *
   * @return {@code true} if {@code currentNode} is {@code null}, or if all its children are bad and its own split
   *         is bad
   */
  private boolean removeBadSplitNodes(SplitCriterion criterion, Node currentNode, double lastCheckRatio,
      double lastCheckSDR, double lastCheckE) {
    if (currentNode == null) {
      return true;
    }

    // Detach bad children through their parent: assigning null to the recursion parameter cannot change the tree.
    boolean leftIsBad = removeBadSplitNodes(criterion, currentNode.left, lastCheckRatio, lastCheckSDR, lastCheckE);
    if (leftIsBad) {
      currentNode.left = null;
    }
    boolean rightIsBad = removeBadSplitNodes(criterion, currentNode.right, lastCheckRatio, lastCheckSDR, lastCheckE);
    if (rightIsBad) {
      currentNode.right = null;
    }
    if (!(leftIsBad && rightIsBad)) {
      // A node with a surviving child must stay so that the child remains reachable
      return false;
    }

    double[][] postSplitDists = new double[][] {
        { currentNode.leftStatistics.getValue(0), currentNode.leftStatistics.getValue(1),
            currentNode.leftStatistics.getValue(2) },
        { currentNode.rightStatistics.getValue(0), currentNode.rightStatistics.getValue(1),
            currentNode.rightStatistics.getValue(2) } };
    double[] preSplitDist = new double[] {
        (currentNode.leftStatistics.getValue(0) + currentNode.rightStatistics.getValue(0)),
        (currentNode.leftStatistics.getValue(1) + currentNode.rightStatistics.getValue(1)),
        (currentNode.leftStatistics.getValue(2) + currentNode.rightStatistics.getValue(2)) };
    double merit = criterion.getMeritOfSplit(preSplitDist, postSplitDists);

    return (merit / lastCheckSDR) < (lastCheckRatio - (2 * lastCheckE));
  }

  @Override
  public void getDescription(StringBuilder sb, int indent) {
    // No description is provided for this observer.
  }

  @Override
  protected void prepareForUseImpl(TaskMonitor monitor, ObjectRepository repository) {
    // Nothing to prepare.
  }
}
http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/9b178f63/samoa-api/src/main/java/org/apache/samoa/moa/classifiers/core/attributeclassobservers/GaussianNumericAttributeClassObserver.java ---------------------------------------------------------------------- diff --git a/samoa-api/src/main/java/org/apache/samoa/moa/classifiers/core/attributeclassobservers/GaussianNumericAttributeClassObserver.java b/samoa-api/src/main/java/org/apache/samoa/moa/classifiers/core/attributeclassobservers/GaussianNumericAttributeClassObserver.java new file mode 100644 index 0000000..0bd25da --- /dev/null +++ b/samoa-api/src/main/java/org/apache/samoa/moa/classifiers/core/attributeclassobservers/GaussianNumericAttributeClassObserver.java @@ -0,0 +1,182 @@ +package org.apache.samoa.moa.classifiers.core.attributeclassobservers; + +/* + * #%L + * SAMOA + * %% + * Copyright (C) 2014 - 2015 Apache Software Foundation + * %% + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ * #L% + */

import java.util.Set;
import java.util.TreeSet;

import org.apache.samoa.moa.classifiers.core.AttributeSplitSuggestion;
import org.apache.samoa.moa.classifiers.core.conditionaltests.NumericAttributeBinaryTest;
import org.apache.samoa.moa.classifiers.core.splitcriteria.SplitCriterion;
import org.apache.samoa.moa.core.AutoExpandVector;
import org.apache.samoa.moa.core.DoubleVector;
import org.apache.samoa.moa.core.GaussianEstimator;
import org.apache.samoa.moa.core.ObjectRepository;
import org.apache.samoa.moa.core.Utils;
import org.apache.samoa.moa.options.AbstractOptionHandler;
import org.apache.samoa.moa.tasks.TaskMonitor;

import com.github.javacliparser.IntOption;

/**
 * Summarises how one numeric attribute is distributed within each class. It keeps one {@link GaussianEstimator} per
 * class value, plus the smallest and largest attribute value seen for that class. Used in naive Bayes and in decision
 * trees, where it proposes and scores binary split points at the leaves.
 *
 * @author Richard Kirkby ([email protected])
 * @version $Revision: 7 $
 */
public class GaussianNumericAttributeClassObserver extends AbstractOptionHandler
    implements NumericAttributeClassObserver {

  private static final long serialVersionUID = 1L;

  /** Smallest attribute value observed so far, indexed by class value. */
  protected DoubleVector minValueObservedPerClass = new DoubleVector();

  /** Largest attribute value observed so far, indexed by class value. */
  protected DoubleVector maxValueObservedPerClass = new DoubleVector();

  /** Estimator of the attribute values, indexed by class value; an entry stays null until that class is observed. */
  protected AutoExpandVector<GaussianEstimator> attValDistPerClass = new AutoExpandVector<>();

  /** How many equally spaced candidate split points are tried between the observed minimum and maximum. */
  public IntOption numBinsOption = new IntOption("numBins", 'n',
      "The number of bins.", 10, 1, Integer.MAX_VALUE);

  /**
   * Looks up the estimator kept for a class.
   *
   * @param classVal the class value (index)
   * @return the estimator for {@code classVal}, or {@code null} if none is present
   */
  public GaussianEstimator getEstimator(int classVal) {
    GaussianEstimator estimator = this.attValDistPerClass.get(classVal);
    return estimator;
  }

  @Override
  public void observeAttributeClass(double attVal, int classVal, double weight) {
    if (Utils.isMissingValue(attVal)) {
      return; // missing values carry no information about the distribution
    }
    GaussianEstimator estimator = this.attValDistPerClass.get(classVal);
    if (estimator == null) {
      // First value seen for this class: it is both the minimum and the maximum so far.
      estimator = new GaussianEstimator();
      this.attValDistPerClass.set(classVal, estimator);
      this.minValueObservedPerClass.setValue(classVal, attVal);
      this.maxValueObservedPerClass.setValue(classVal, attVal);
    } else {
      boolean belowMin = attVal < this.minValueObservedPerClass.getValue(classVal);
      if (belowMin) {
        this.minValueObservedPerClass.setValue(classVal, attVal);
      }
      boolean aboveMax = attVal > this.maxValueObservedPerClass.getValue(classVal);
      if (aboveMax) {
        this.maxValueObservedPerClass.setValue(classVal, attVal);
      }
    }
    estimator.addObservation(attVal, weight);
  }

  @Override
  public double probabilityOfAttributeValueGivenClass(double attVal,
      int classVal) {
    GaussianEstimator estimator = this.attValDistPerClass.get(classVal);
    if (estimator == null) {
      return 0.0;
    }
    return estimator.probabilityDensity(attVal);
  }

  @Override
  public AttributeSplitSuggestion getBestEvaluatedSplitSuggestion(
      SplitCriterion criterion, double[] preSplitDist, int attIndex,
      boolean binaryOnly) {
    AttributeSplitSuggestion best = null;
    for (double candidate : getSplitPointSuggestions()) {
      double[][] dists = getClassDistsResultingFromBinarySplit(candidate);
      double merit = criterion.getMeritOfSplit(preSplitDist, dists);
      boolean improves = (best == null) || (merit > best.merit);
      if (improves) {
        NumericAttributeBinaryTest test = new NumericAttributeBinaryTest(attIndex, candidate, true);
        best = new AttributeSplitSuggestion(test, dists, merit);
      }
    }
    return best;
  }

  /**
   * Proposes candidate split points: {@code numBins} equally spaced values across the overall observed range. Only
   * values strictly between the global minimum and maximum are kept, deduplicated and sorted ascending.
   *
   * @return the candidate split values; empty if nothing has been observed
   */
  public double[] getSplitPointSuggestions() {
    double lowest = Double.POSITIVE_INFINITY;
    double highest = Double.NEGATIVE_INFINITY;
    for (int classIdx = 0; classIdx < this.attValDistPerClass.size(); classIdx++) {
      if (this.attValDistPerClass.get(classIdx) == null) {
        continue; // class never observed
      }
      double classMin = this.minValueObservedPerClass.getValue(classIdx);
      if (classMin < lowest) {
        lowest = classMin;
      }
      double classMax = this.maxValueObservedPerClass.getValue(classIdx);
      if (classMax > highest) {
        highest = classMax;
      }
    }

    Set<Double> candidates = new TreeSet<>();
    if (lowest < Double.POSITIVE_INFINITY) {
      double range = highest - lowest;
      int binCount = this.numBinsOption.getValue();
      for (int i = 0; i < binCount; i++) {
        double candidate = range / (binCount + 1.0) * (i + 1) + lowest;
        if ((candidate > lowest) && (candidate < highest)) {
          candidates.add(candidate);
        }
      }
    }

    double[] result = new double[candidates.size()];
    int pos = 0;
    for (Double candidate : candidates) {
      result[pos] = candidate;
      pos++;
    }
    return result;
  }

  /**
   * Estimates the per-class weight that falls on each side of a binary split. Values equal to {@code splitValue} are
   * assigned to the left-hand side.
   *
   * @return a two-row array: left-hand class weights, then right-hand class weights
   */
  public double[][] getClassDistsResultingFromBinarySplit(double splitValue) {
    DoubleVector left = new DoubleVector();
    DoubleVector right = new DoubleVector();
    for (int classIdx = 0; classIdx < this.attValDistPerClass.size(); classIdx++) {
      GaussianEstimator estimator = this.attValDistPerClass.get(classIdx);
      if (estimator == null) {
        continue;
      }
      if (splitValue < this.minValueObservedPerClass.getValue(classIdx)) {
        // every observed value of this class lies above the split
        right.addToValue(classIdx, estimator.getTotalWeightObserved());
      } else if (splitValue >= this.maxValueObservedPerClass.getValue(classIdx)) {
        // every observed value of this class lies at or below the split
        left.addToValue(classIdx, estimator.getTotalWeightObserved());
      } else {
        double[] lessEqualGreater = estimator.estimatedWeight_LessThan_EqualTo_GreaterThan_Value(splitValue);
        left.addToValue(classIdx, lessEqualGreater[0] + lessEqualGreater[1]);
        right.addToValue(classIdx, lessEqualGreater[2]);
      }
    }
    return new double[][] { left.getArrayRef(), right.getArrayRef() };
  }

  @Override
  public void getDescription(StringBuilder sb, int indent) {
    // no description provided
  }

  @Override
  protected void prepareForUseImpl(TaskMonitor monitor, ObjectRepository repository) {
    // nothing to prepare
  }

  @Override
  public void observeAttributeTarget(double attVal, double target) {
    throw new UnsupportedOperationException("Not supported yet.");
  }
}
