http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/23a35dbe/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/clusterers/SingleLearner.java ---------------------------------------------------------------------- diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/clusterers/SingleLearner.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/clusterers/SingleLearner.java index 894a0cc..ae8684b 100644 --- a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/clusterers/SingleLearner.java +++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/clusterers/SingleLearner.java @@ -42,56 +42,59 @@ import com.yahoo.labs.samoa.topology.TopologyBuilder; */ public final class SingleLearner implements Learner, Configurable { - private static final long serialVersionUID = 684111382631697031L; - - private LocalClustererProcessor learnerP; - - private Stream resultStream; - - private Instances dataset; - - public ClassOption learnerOption = new ClassOption("learner", 'l', - "Learner to train.", LocalClustererAdapter.class, ClustreamClustererAdapter.class.getName()); - - private TopologyBuilder builder; - - private int parallelism; - - @Override - public void init(TopologyBuilder builder, Instances dataset, int parallelism){ - this.builder = builder; - this.dataset = dataset; - this.parallelism = parallelism; - this.setLayout(); - } - - - protected void setLayout() { - learnerP = new LocalClustererProcessor(); - LocalClustererAdapter learner = (LocalClustererAdapter) this.learnerOption.getValue(); - learner.setDataset(this.dataset); - learnerP.setLearner(learner); - - this.builder.addProcessor(learnerP, this.parallelism); - resultStream = this.builder.createStream(learnerP); - - learnerP.setOutputStream(resultStream); - } - - /* (non-Javadoc) - * @see samoa.classifiers.Classifier#getInputProcessingItem() - */ - @Override - public Processor getInputProcessor() { - return learnerP; - } - - /* (non-Javadoc) - * @see samoa.learners.Learner#getResultStreams() - */ - @Override - public Set<Stream> getResultStreams() { - Set<Stream> streams = ImmutableSet.of(this.resultStream); - return streams; - } + private static final long serialVersionUID = 684111382631697031L; + + private LocalClustererProcessor learnerP; + + private Stream resultStream; + + private Instances dataset; + + public ClassOption learnerOption = new ClassOption("learner", 'l', + "Learner to train.", LocalClustererAdapter.class, ClustreamClustererAdapter.class.getName()); + + private TopologyBuilder builder; + + private int parallelism; + + @Override + public void init(TopologyBuilder builder, Instances dataset, int parallelism) { + this.builder = builder; + this.dataset = dataset; + this.parallelism = parallelism; + this.setLayout(); + } + + protected void setLayout() { + learnerP = new LocalClustererProcessor(); + LocalClustererAdapter learner = (LocalClustererAdapter) this.learnerOption.getValue(); + learner.setDataset(this.dataset); + learnerP.setLearner(learner); + + this.builder.addProcessor(learnerP, this.parallelism); + resultStream = this.builder.createStream(learnerP); + + learnerP.setOutputStream(resultStream); + } + + /* + * (non-Javadoc) + * + * @see samoa.classifiers.Classifier#getInputProcessingItem() + */ + @Override + public Processor getInputProcessor() { + return learnerP; + } + + /* + * (non-Javadoc) + * + * @see samoa.learners.Learner#getResultStreams() + */ + @Override + public Set<Stream> getResultStreams() { + Set<Stream> streams = ImmutableSet.of(this.resultStream); + return streams; + } }
http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/23a35dbe/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/clusterers/simple/ClusteringDistributorProcessor.java ---------------------------------------------------------------------- diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/clusterers/simple/ClusteringDistributorProcessor.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/clusterers/simple/ClusteringDistributorProcessor.java index e75a1bd..7d266f5 100644 --- a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/clusterers/simple/ClusteringDistributorProcessor.java +++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/clusterers/simple/ClusteringDistributorProcessor.java @@ -34,65 +34,67 @@ import com.yahoo.labs.samoa.topology.Stream; */ public class ClusteringDistributorProcessor implements Processor { - private static final long serialVersionUID = -1550901409625192730L; + private static final long serialVersionUID = -1550901409625192730L; - private Stream outputStream; - private Stream evaluationStream; - private int numInstances; + private Stream outputStream; + private Stream evaluationStream; + private int numInstances; - public Stream getOutputStream() { - return outputStream; - } + public Stream getOutputStream() { + return outputStream; + } - public void setOutputStream(Stream outputStream) { - this.outputStream = outputStream; - } + public void setOutputStream(Stream outputStream) { + this.outputStream = outputStream; + } - public Stream getEvaluationStream() { - return evaluationStream; - } + public Stream getEvaluationStream() { + return evaluationStream; + } - public void setEvaluationStream(Stream evaluationStream) { - this.evaluationStream = evaluationStream; - } + public void setEvaluationStream(Stream evaluationStream) { + this.evaluationStream = evaluationStream; + } - /** - * Process event. - * - * @param event - * the event - * @return true, if successful - */ - public boolean process(ContentEvent event) { - // distinguish between ClusteringContentEvent and ClusteringEvaluationContentEvent - if (event instanceof ClusteringContentEvent) { - ClusteringContentEvent cce = (ClusteringContentEvent) event; - outputStream.put(event); - if (cce.isSample()) { - evaluationStream.put(new ClusteringEvaluationContentEvent(null, new DataPoint(cce.getInstance(), numInstances++), cce.isLastEvent())); - } - } else if (event instanceof ClusteringEvaluationContentEvent) { - evaluationStream.put(event); - } - return true; + /** + * Process event. + * + * @param event + * the event + * @return true, if successful + */ + public boolean process(ContentEvent event) { + // distinguish between ClusteringContentEvent and + // ClusteringEvaluationContentEvent + if (event instanceof ClusteringContentEvent) { + ClusteringContentEvent cce = (ClusteringContentEvent) event; + outputStream.put(event); + if (cce.isSample()) { + evaluationStream.put(new ClusteringEvaluationContentEvent(null, + new DataPoint(cce.getInstance(), numInstances++), cce.isLastEvent())); + } + } else if (event instanceof ClusteringEvaluationContentEvent) { + evaluationStream.put(event); } + return true; + } - /* - * (non-Javadoc) - * - * @see samoa.core.Processor#newProcessor(samoa.core.Processor) - */ - @Override - public Processor newProcessor(Processor sourceProcessor) { - ClusteringDistributorProcessor newProcessor = new ClusteringDistributorProcessor(); - ClusteringDistributorProcessor originProcessor = (ClusteringDistributorProcessor) sourceProcessor; - if (originProcessor.getOutputStream() != null) - newProcessor.setOutputStream(originProcessor.getOutputStream()); - if (originProcessor.getEvaluationStream() != null) - newProcessor.setEvaluationStream(originProcessor.getEvaluationStream()); - return newProcessor; - } + /* + * (non-Javadoc) + * + * @see samoa.core.Processor#newProcessor(samoa.core.Processor) + */ + @Override + public Processor newProcessor(Processor sourceProcessor) { + ClusteringDistributorProcessor newProcessor = new ClusteringDistributorProcessor(); + ClusteringDistributorProcessor originProcessor = (ClusteringDistributorProcessor) sourceProcessor; + if (originProcessor.getOutputStream() != null) + newProcessor.setOutputStream(originProcessor.getOutputStream()); + if (originProcessor.getEvaluationStream() != null) + newProcessor.setEvaluationStream(originProcessor.getEvaluationStream()); + return newProcessor; + } - public void onCreate(int id) { - } + public void onCreate(int id) { + } } http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/23a35dbe/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/clusterers/simple/DistributedClusterer.java ---------------------------------------------------------------------- diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/clusterers/simple/DistributedClusterer.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/clusterers/simple/DistributedClusterer.java index d924733..edfecfa 100644 --- a/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/clusterers/simple/DistributedClusterer.java +++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/learners/clusterers/simple/DistributedClusterer.java @@ -45,74 +45,75 @@ import com.yahoo.labs.samoa.topology.TopologyBuilder; */ public final class DistributedClusterer implements Learner, Configurable { - private static final long serialVersionUID = 684111382631697031L; - - private Stream resultStream; - - private Instances dataset; - - public ClassOption learnerOption = new ClassOption("learner", 'l', "Clusterer to use.", LocalClustererAdapter.class, - ClustreamClustererAdapter.class.getName()); - - public IntOption paralellismOption = new IntOption("paralellismOption", 'P', "The paralellism level for concurrent processes", 2, 1, Integer.MAX_VALUE); - - private TopologyBuilder builder; - -// private ClusteringDistributorProcessor distributorP; - private LocalClustererProcessor learnerP; - -// private Stream distributorToLocalStream; - private Stream localToGlobalStream; - -// private int parallelism; - - @Override - public void init(TopologyBuilder builder, Instances dataset, int parallelism) { - this.builder = builder; - this.dataset = dataset; -// this.parallelism = parallelism; - this.setLayout(); - } - - protected void setLayout() { - // Distributor -// distributorP = new ClusteringDistributorProcessor(); -// this.builder.addProcessor(distributorP, parallelism); -// distributorToLocalStream = this.builder.createStream(distributorP); -// distributorP.setOutputStream(distributorToLocalStream); -// distributorToGlobalStream = this.builder.createStream(distributorP); - - // Local Clustering - learnerP = new LocalClustererProcessor(); - LocalClustererAdapter learner = (LocalClustererAdapter) this.learnerOption.getValue(); - learner.setDataset(this.dataset); - learnerP.setLearner(learner); - builder.addProcessor(learnerP, this.paralellismOption.getValue()); - localToGlobalStream = this.builder.createStream(learnerP); - learnerP.setOutputStream(localToGlobalStream); - - // Global Clustering - LocalClustererProcessor globalClusteringCombinerP = new LocalClustererProcessor(); - LocalClustererAdapter globalLearner = (LocalClustererAdapter) this.learnerOption.getValue(); - globalLearner.setDataset(this.dataset); - globalClusteringCombinerP.setLearner(learner); - builder.addProcessor(globalClusteringCombinerP, 1); - builder.connectInputAllStream(localToGlobalStream, globalClusteringCombinerP); - - // Output Stream - resultStream = this.builder.createStream(globalClusteringCombinerP); - globalClusteringCombinerP.setOutputStream(resultStream); - } - - @Override - public Processor getInputProcessor() { -// return distributorP; - return learnerP; - } - - @Override - public Set<Stream> getResultStreams() { - Set<Stream> streams = ImmutableSet.of(this.resultStream); - return streams; - } + private static final long serialVersionUID = 684111382631697031L; + + private Stream resultStream; + + private Instances dataset; + + public ClassOption learnerOption = new ClassOption("learner", 'l', "Clusterer to use.", LocalClustererAdapter.class, + ClustreamClustererAdapter.class.getName()); + + public IntOption paralellismOption = new IntOption("paralellismOption", 'P', + "The paralellism level for concurrent processes", 2, 1, Integer.MAX_VALUE); + + private TopologyBuilder builder; + + // private ClusteringDistributorProcessor distributorP; + private LocalClustererProcessor learnerP; + + // private Stream distributorToLocalStream; + private Stream localToGlobalStream; + + // private int parallelism; + + @Override + public void init(TopologyBuilder builder, Instances dataset, int parallelism) { + this.builder = builder; + this.dataset = dataset; + // this.parallelism = parallelism; + this.setLayout(); + } + + protected void setLayout() { + // Distributor + // distributorP = new ClusteringDistributorProcessor(); + // this.builder.addProcessor(distributorP, parallelism); + // distributorToLocalStream = this.builder.createStream(distributorP); + // distributorP.setOutputStream(distributorToLocalStream); + // distributorToGlobalStream = this.builder.createStream(distributorP); + + // Local Clustering + learnerP = new LocalClustererProcessor(); + LocalClustererAdapter learner = (LocalClustererAdapter) this.learnerOption.getValue(); + learner.setDataset(this.dataset); + learnerP.setLearner(learner); + builder.addProcessor(learnerP, this.paralellismOption.getValue()); + localToGlobalStream = this.builder.createStream(learnerP); + learnerP.setOutputStream(localToGlobalStream); + + // Global Clustering + LocalClustererProcessor globalClusteringCombinerP = new LocalClustererProcessor(); + LocalClustererAdapter globalLearner = (LocalClustererAdapter) this.learnerOption.getValue(); + globalLearner.setDataset(this.dataset); + globalClusteringCombinerP.setLearner(learner); + builder.addProcessor(globalClusteringCombinerP, 1); + builder.connectInputAllStream(localToGlobalStream, globalClusteringCombinerP); + + // Output Stream + resultStream = this.builder.createStream(globalClusteringCombinerP); + globalClusteringCombinerP.setOutputStream(resultStream); + } + + @Override + public Processor getInputProcessor() { + // return distributorP; + return learnerP; + } + + @Override + public Set<Stream> getResultStreams() { + Set<Stream> streams = ImmutableSet.of(this.resultStream); + return streams; + } } http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/23a35dbe/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/AbstractMOAObject.java ---------------------------------------------------------------------- diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/AbstractMOAObject.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/AbstractMOAObject.java index 37303ec..45fa228 100644 --- a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/AbstractMOAObject.java +++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/AbstractMOAObject.java @@ -21,60 +21,63 @@ package com.yahoo.labs.samoa.moa; */ import com.yahoo.labs.samoa.moa.core.SerializeUtils; + //import moa.core.SizeOf; /** - * Abstract MOA Object. All classes that are serializable, copiable, - * can measure its size, and can give a description, extend this class. - * + * Abstract MOA Object. All classes that are serializable, copiable, can measure + * its size, and can give a description, extend this class. + * * @author Richard Kirkby ([email protected]) * @version $Revision: 7 $ */ public abstract class AbstractMOAObject implements MOAObject { - @Override - public MOAObject copy() { - return copy(this); - } + @Override + public MOAObject copy() { + return copy(this); + } - @Override - public int measureByteSize() { - return measureByteSize(this); - } + @Override + public int measureByteSize() { + return measureByteSize(this); + } - /** - * Returns a description of the object. - * - * @return a description of the object - */ - @Override - public String toString() { - StringBuilder sb = new StringBuilder(); - getDescription(sb, 0); - return sb.toString(); - } + /** + * Returns a description of the object. + * + * @return a description of the object + */ + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + getDescription(sb, 0); + return sb.toString(); + } - /** - * This method produces a copy of an object. - * - * @param obj object to copy - * @return a copy of the object - */ - public static MOAObject copy(MOAObject obj) { - try { - return (MOAObject) SerializeUtils.copyObject(obj); - } catch (Exception e) { - throw new RuntimeException("Object copy failed.", e); - } + /** + * This method produces a copy of an object. + * + * @param obj + * object to copy + * @return a copy of the object + */ + public static MOAObject copy(MOAObject obj) { + try { + return (MOAObject) SerializeUtils.copyObject(obj); + } catch (Exception e) { + throw new RuntimeException("Object copy failed.", e); } + } - /** - * Gets the memory size of an object. - * - * @param obj object to measure the memory size - * @return the memory size of this object - */ - public static int measureByteSize(MOAObject obj) { - return 0; //(int) SizeOf.fullSizeOf(obj); - } + /** + * Gets the memory size of an object. + * + * @param obj + * object to measure the memory size + * @return the memory size of this object + */ + public static int measureByteSize(MOAObject obj) { + return 0; // (int) SizeOf.fullSizeOf(obj); + } } http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/23a35dbe/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/MOAObject.java ---------------------------------------------------------------------- diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/MOAObject.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/MOAObject.java index cc26eaa..cd98892 100644 --- a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/MOAObject.java +++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/MOAObject.java @@ -23,36 +23,38 @@ package com.yahoo.labs.samoa.moa; import java.io.Serializable; /** - * Interface implemented by classes in MOA, so that all are serializable, - * can produce copies of their objects, and can measure its memory size. - * They also give a string description. - * + * Interface implemented by classes in MOA, so that all are serializable, can + * produce copies of their objects, and can measure its memory size. They also + * give a string description. + * * @author Richard Kirkby ([email protected]) * @version $Revision: 7 $ */ public interface MOAObject extends Serializable { - /** - * Gets the memory size of this object. - * - * @return the memory size of this object - */ - public int measureByteSize(); + /** + * Gets the memory size of this object. + * + * @return the memory size of this object + */ + public int measureByteSize(); - /** - * This method produces a copy of this object. - * - * @return a copy of this object - */ - public MOAObject copy(); + /** + * This method produces a copy of this object. + * + * @return a copy of this object + */ + public MOAObject copy(); - /** - * Returns a string representation of this object. - * Used in <code>AbstractMOAObject.toString</code> - * to give a string representation of the object. - * - * @param sb the stringbuilder to add the description - * @param indent the number of characters to indent - */ - public void getDescription(StringBuilder sb, int indent); + /** + * Returns a string representation of this object. Used in + * <code>AbstractMOAObject.toString</code> to give a string representation of + * the object. + * + * @param sb + * the stringbuilder to add the description + * @param indent + * the number of characters to indent + */ + public void getDescription(StringBuilder sb, int indent); } http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/23a35dbe/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/AbstractClassifier.java ---------------------------------------------------------------------- diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/AbstractClassifier.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/AbstractClassifier.java index 09de49e..21bbf4b 100644 --- a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/AbstractClassifier.java +++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/AbstractClassifier.java @@ -1,4 +1,3 @@ - package com.yahoo.labs.samoa.moa.classifiers; /* @@ -41,338 +40,347 @@ import com.yahoo.labs.samoa.moa.tasks.TaskMonitor; public abstract class AbstractClassifier extends AbstractOptionHandler implements Classifier { - @Override - public String getPurposeString() { - return "MOA Classifier: " + getClass().getCanonicalName(); - } + @Override + public String getPurposeString() { + return "MOA Classifier: " + getClass().getCanonicalName(); + } - /** Header of the instances of the data stream */ - protected InstancesHeader modelContext; + /** Header of the instances of the data stream */ + protected InstancesHeader modelContext; - /** Sum of the weights of the instances trained by this model */ - protected double trainingWeightSeenByModel = 0.0; + /** Sum of the weights of the instances trained by this model */ + protected double trainingWeightSeenByModel = 0.0; - /** Random seed used in randomizable learners */ - protected int randomSeed = 1; + /** Random seed used in randomizable learners */ + protected int randomSeed = 1; - /** Option for randomizable learners to change the random seed */ - protected IntOption randomSeedOption; + /** Option for randomizable learners to change the random seed */ + protected IntOption randomSeedOption; - /** Random Generator used in randomizable learners */ - public Random classifierRandom; + /** Random Generator used in randomizable learners */ + public Random classifierRandom; - /** - * Creates an classifier and setups the random seed option - * if the classifier is randomizable. - */ - public AbstractClassifier() { - if (isRandomizable()) { - this.randomSeedOption = new IntOption("randomSeed", 'r', - "Seed for random behaviour of the classifier.", 1); - } + /** + * Creates an classifier and setups the random seed option if the classifier + * is randomizable. + */ + public AbstractClassifier() { + if (isRandomizable()) { + this.randomSeedOption = new IntOption("randomSeed", 'r', + "Seed for random behaviour of the classifier.", 1); } + } - @Override - public void prepareForUseImpl(TaskMonitor monitor, - ObjectRepository repository) { - if (this.randomSeedOption != null) { - this.randomSeed = this.randomSeedOption.getValue(); - } - if (!trainingHasStarted()) { - resetLearning(); - } - } - - - @Override - public double[] getVotesForInstance(Example<Instance> example){ - return getVotesForInstance(example.getData()); - } - - @Override - public abstract double[] getVotesForInstance(Instance inst); - - @Override - public void setModelContext(InstancesHeader ih) { - if ((ih != null) && (ih.classIndex() < 0)) { - throw new IllegalArgumentException( - "Context for a classifier must include a class to learn"); - } - if (trainingHasStarted() - && (this.modelContext != null) - && ((ih == null) || !contextIsCompatible(this.modelContext, ih))) { - throw new IllegalArgumentException( - "New context is not compatible with existing model"); - } - this.modelContext = ih; + @Override + public void prepareForUseImpl(TaskMonitor monitor, + ObjectRepository repository) { + if (this.randomSeedOption != null) { + this.randomSeed = this.randomSeedOption.getValue(); } - - @Override - public InstancesHeader getModelContext() { - return this.modelContext; - } - - @Override - public void setRandomSeed(int s) { - this.randomSeed = s; - if (this.randomSeedOption != null) { - // keep option consistent - this.randomSeedOption.setValue(s); - } - } - - @Override - public boolean trainingHasStarted() { - return this.trainingWeightSeenByModel > 0.0; - } - - @Override - public double trainingWeightSeenByModel() { - return this.trainingWeightSeenByModel; - } - - @Override - public void resetLearning() { - this.trainingWeightSeenByModel = 0.0; - if (isRandomizable()) { - this.classifierRandom = new Random(this.randomSeed); - } - resetLearningImpl(); + if (!trainingHasStarted()) { + resetLearning(); } + } - @Override - public void trainOnInstance(Instance inst) { - if (inst.weight() > 0.0) { - this.trainingWeightSeenByModel += inst.weight(); - trainOnInstanceImpl(inst); - } - } + @Override + public double[] getVotesForInstance(Example<Instance> example) { + return getVotesForInstance(example.getData()); + } - @Override - public Measurement[] getModelMeasurements() { - List<Measurement> measurementList = new LinkedList<>(); - measurementList.add(new Measurement("model training instances", - trainingWeightSeenByModel())); - measurementList.add(new Measurement("model serialized size (bytes)", - measureByteSize())); - Measurement[] modelMeasurements = getModelMeasurementsImpl(); - if (modelMeasurements != null) { - measurementList.addAll(Arrays.asList(modelMeasurements)); - } - // add average of sub-model measurements - Learner[] subModels = getSublearners(); - if ((subModels != null) && (subModels.length > 0)) { - List<Measurement[]> subMeasurements = new LinkedList<>(); - for (Learner subModel : subModels) { - if (subModel != null) { - subMeasurements.add(subModel.getModelMeasurements()); - } - } - Measurement[] avgMeasurements = Measurement.averageMeasurements(subMeasurements.toArray(new Measurement[subMeasurements.size()][])); - measurementList.addAll(Arrays.asList(avgMeasurements)); - } - return measurementList.toArray(new Measurement[measurementList.size()]); - } + @Override + public abstract double[] getVotesForInstance(Instance inst); - @Override - public void getDescription(StringBuilder out, int indent) { - StringUtils.appendIndented(out, indent, "Model type: "); - out.append(this.getClass().getName()); - StringUtils.appendNewline(out); - Measurement.getMeasurementsDescription(getModelMeasurements(), out, - indent); - StringUtils.appendNewlineIndented(out, indent, "Model description:"); - StringUtils.appendNewline(out); - if (trainingHasStarted()) { - getModelDescription(out, indent); - } else { - StringUtils.appendIndented(out, indent, - "Model has not been trained."); - } + @Override + public void setModelContext(InstancesHeader ih) { + if ((ih != null) && (ih.classIndex() < 0)) { + throw new IllegalArgumentException( + "Context for a classifier must include a class to learn"); } - - @Override - public Learner[] getSublearners() { - return null; + if (trainingHasStarted() + && (this.modelContext != null) + && ((ih == null) || !contextIsCompatible(this.modelContext, ih))) { + throw new IllegalArgumentException( + "New context is not compatible with existing model"); } - - - @Override - public Classifier[] getSubClassifiers() { - return null; + this.modelContext = ih; + } + + @Override + public InstancesHeader getModelContext() { + return this.modelContext; + } + + @Override + public void setRandomSeed(int s) { + this.randomSeed = s; + if (this.randomSeedOption != null) { + // keep option consistent + this.randomSeedOption.setValue(s); } - - - @Override - public Classifier copy() { - return (Classifier) super.copy(); + } + + @Override + public boolean trainingHasStarted() { + return this.trainingWeightSeenByModel > 0.0; + } + + @Override + public double trainingWeightSeenByModel() { + return this.trainingWeightSeenByModel; + } + + @Override + public void resetLearning() { + this.trainingWeightSeenByModel = 0.0; + if (isRandomizable()) { + this.classifierRandom = new Random(this.randomSeed); } - - - @Override - public MOAObject getModel(){ - return this; + resetLearningImpl(); + } + + @Override + public void trainOnInstance(Instance inst) { + if (inst.weight() > 0.0) { + this.trainingWeightSeenByModel += inst.weight(); + trainOnInstanceImpl(inst); } - - @Override - public void trainOnInstance(Example<Instance> example){ - trainOnInstance(example.getData()); - } - - @Override - public boolean correctlyClassifies(Instance inst) { - return Utils.maxIndex(getVotesForInstance(inst)) == (int) inst.classValue(); + } + + @Override + public Measurement[] getModelMeasurements() { + List<Measurement> measurementList = new LinkedList<>(); + measurementList.add(new Measurement("model training instances", + trainingWeightSeenByModel())); + measurementList.add(new Measurement("model serialized size (bytes)", + measureByteSize())); + Measurement[] modelMeasurements = getModelMeasurementsImpl(); + if (modelMeasurements != null) { + measurementList.addAll(Arrays.asList(modelMeasurements)); } - - /** - * Gets the name of the attribute of the class from the header. - * - * @return the string with name of the attribute of the class - */ - public String getClassNameString() { - return InstancesHeader.getClassNameString(this.modelContext); + // add average of sub-model measurements + Learner[] subModels = getSublearners(); + if ((subModels != null) && (subModels.length > 0)) { + List<Measurement[]> subMeasurements = new LinkedList<>(); + for (Learner subModel : subModels) { + if (subModel != null) { + subMeasurements.add(subModel.getModelMeasurements()); + } + } + Measurement[] avgMeasurements = Measurement.averageMeasurements(subMeasurements + .toArray(new Measurement[subMeasurements.size()][])); + measurementList.addAll(Arrays.asList(avgMeasurements)); } - - /** - * Gets the name of a label of the class from the header. - * - * @param classLabelIndex the label index - * @return the name of the label of the class - */ - public String getClassLabelString(int classLabelIndex) { - return InstancesHeader.getClassLabelString(this.modelContext, - classLabelIndex); + return measurementList.toArray(new Measurement[measurementList.size()]); + } + + @Override + public void getDescription(StringBuilder out, int indent) { + StringUtils.appendIndented(out, indent, "Model type: "); + out.append(this.getClass().getName()); + StringUtils.appendNewline(out); + Measurement.getMeasurementsDescription(getModelMeasurements(), out, + indent); + StringUtils.appendNewlineIndented(out, indent, "Model description:"); + StringUtils.appendNewline(out); + if (trainingHasStarted()) { + getModelDescription(out, indent); + } else { + StringUtils.appendIndented(out, indent, + "Model has not been trained."); } - - /** - * Gets the name of an attribute from the header. - * - * @param attIndex the attribute index - * @return the name of the attribute - */ - public String getAttributeNameString(int attIndex) { - return InstancesHeader.getAttributeNameString(this.modelContext, attIndex); + } + + @Override + public Learner[] getSublearners() { + return null; + } + + @Override + public Classifier[] getSubClassifiers() { + return null; + } + + @Override + public Classifier copy() { + return (Classifier) super.copy(); + } + + @Override + public MOAObject getModel() { + return this; + } + + @Override + public void trainOnInstance(Example<Instance> example) { + trainOnInstance(example.getData()); + } + + @Override + public boolean correctlyClassifies(Instance inst) { + return Utils.maxIndex(getVotesForInstance(inst)) == (int) inst.classValue(); + } + + /** + * Gets the name of the attribute of the class from the header. + * + * @return the string with name of the attribute of the class + */ + public String getClassNameString() { + return InstancesHeader.getClassNameString(this.modelContext); + } + + /** + * Gets the name of a label of the class from the header. + * + * @param classLabelIndex + * the label index + * @return the name of the label of the class + */ + public String getClassLabelString(int classLabelIndex) { + return InstancesHeader.getClassLabelString(this.modelContext, + classLabelIndex); + } + + /** + * Gets the name of an attribute from the header. + * + * @param attIndex + * the attribute index + * @return the name of the attribute + */ + public String getAttributeNameString(int attIndex) { + return InstancesHeader.getAttributeNameString(this.modelContext, attIndex); + } + + /** + * Gets the name of a value of an attribute from the header. + * + * @param attIndex + * the attribute index + * @param valIndex + * the value of the attribute + * @return the name of the value of the attribute + */ + public String getNominalValueString(int attIndex, int valIndex) { + return InstancesHeader.getNominalValueString(this.modelContext, attIndex, valIndex); + } + + /** + * Returns if two contexts or headers of instances are compatible.<br> + * <br> + * + * Two contexts are compatible if they follow the following rules:<br> + * Rule 1: num classes can increase but never decrease<br> + * Rule 2: num attributes can increase but never decrease<br> + * Rule 3: num nominal attribute values can increase but never decrease<br> + * Rule 4: attribute types must stay in the same order (although class can + * move; is always skipped over)<br> + * <br> + * + * Attribute names are free to change, but should always still represent the + * original attributes. + * + * @param originalContext + * the first context to compare + * @param newContext + * the second context to compare + * @return true if the two contexts are compatible. + */ + public static boolean contextIsCompatible(InstancesHeader originalContext, + InstancesHeader newContext) { + + if (newContext.numClasses() < originalContext.numClasses()) { + return false; // rule 1 } - - /** - * Gets the name of a value of an attribute from the header. - * - * @param attIndex the attribute index - * @param valIndex the value of the attribute - * @return the name of the value of the attribute - */ - public String getNominalValueString(int attIndex, int valIndex) { - return InstancesHeader.getNominalValueString(this.modelContext, attIndex, valIndex); + if (newContext.numAttributes() < originalContext.numAttributes()) { + return false; // rule 2 } - - - /** - * Returns if two contexts or headers of instances are compatible.<br><br> - * - * Two contexts are compatible if they follow the following rules:<br> - * Rule 1: num classes can increase but never decrease<br> - * Rule 2: num attributes can increase but never decrease<br> - * Rule 3: num nominal attribute values can increase but never decrease<br> - * Rule 4: attribute types must stay in the same order (although class - * can move; is always skipped over)<br><br> - * - * Attribute names are free to change, but should always still represent - * the original attributes. - * - * @param originalContext the first context to compare - * @param newContext the second context to compare - * @return true if the two contexts are compatible. - */ - public static boolean contextIsCompatible(InstancesHeader originalContext, - InstancesHeader newContext) { - - if (newContext.numClasses() < originalContext.numClasses()) { - return false; // rule 1 + int oPos = 0; + int nPos = 0; + while (oPos < originalContext.numAttributes()) { + if (oPos == originalContext.classIndex()) { + oPos++; + if (!(oPos < originalContext.numAttributes())) { + break; } - if (newContext.numAttributes() < originalContext.numAttributes()) { - return false; // rule 2 + } + if (nPos == newContext.classIndex()) { + nPos++; + } + if (originalContext.attribute(oPos).isNominal()) { + if (!newContext.attribute(nPos).isNominal()) { + return false; // rule 4 } - int oPos = 0; - int nPos = 0; - while (oPos < originalContext.numAttributes()) { - if (oPos == originalContext.classIndex()) { - oPos++; - if (!(oPos < originalContext.numAttributes())) { - break; - } - } - if (nPos == newContext.classIndex()) { - nPos++; - } - if (originalContext.attribute(oPos).isNominal()) { - if (!newContext.attribute(nPos).isNominal()) { - return false; // rule 4 - } - if (newContext.attribute(nPos).numValues() < originalContext.attribute(oPos).numValues()) { - return false; // rule 3 - } - } else { - assert (originalContext.attribute(oPos).isNumeric()); - if (!newContext.attribute(nPos).isNumeric()) { - return false; // rule 4 - } - } - oPos++; - nPos++; + if (newContext.attribute(nPos).numValues() < originalContext.attribute(oPos).numValues()) { + return false; // rule 3 } - return true; // all checks clear - } - - - - /** - * Resets this classifier. It must be similar to - * starting a new classifier from scratch. <br><br> - * - * The reason for ...Impl methods: ease programmer burden by not requiring - * them to remember calls to super in overridden methods. - * Note that this will produce compiler errors if not overridden. - */ - public abstract void resetLearningImpl(); - - /** - * Trains this classifier incrementally using the given instance.<br><br> - * - * The reason for ...Impl methods: ease programmer burden by not requiring - * them to remember calls to super in overridden methods. - * Note that this will produce compiler errors if not overridden. - * - * @param inst the instance to be used for training - */ - public abstract void trainOnInstanceImpl(Instance inst); - - /** - * Gets the current measurements of this classifier.<br><br> - * - * The reason for ...Impl methods: ease programmer burden by not requiring - * them to remember calls to super in overridden methods. - * Note that this will produce compiler errors if not overridden. - * - * @return an array of measurements to be used in evaluation tasks - */ - protected abstract Measurement[] getModelMeasurementsImpl(); - - /** - * Returns a string representation of the model. - * - * @param out the stringbuilder to add the description - * @param indent the number of characters to indent - */ - public abstract void getModelDescription(StringBuilder out, int indent); - - /** - * Gets the index of the attribute in the instance, - * given the index of the attribute in the learner. - * - * @param index the index of the attribute in the learner - * @return the index in the instance - */ - protected static int modelAttIndexToInstanceAttIndex(int index) { - return index; //inst.classIndex() > index ? index : index + 1; + } else { + assert (originalContext.attribute(oPos).isNumeric()); + if (!newContext.attribute(nPos).isNumeric()) { + return false; // rule 4 + } + } + oPos++; + nPos++; } + return true; // all checks clear + } + + /** + * Resets this classifier. It must be similar to starting a new classifier + * from scratch. <br> + * <br> + * + * The reason for ...Impl methods: ease programmer burden by not requiring + * them to remember calls to super in overridden methods. Note that this will + * produce compiler errors if not overridden. + */ + public abstract void resetLearningImpl(); + + /** + * Trains this classifier incrementally using the given instance.<br> + * <br> + * + * The reason for ...Impl methods: ease programmer burden by not requiring + * them to remember calls to super in overridden methods. Note that this will + * produce compiler errors if not overridden. + * + * @param inst + * the instance to be used for training + */ + public abstract void trainOnInstanceImpl(Instance inst); + + /** + * Gets the current measurements of this classifier.<br> + * <br> + * + * The reason for ...Impl methods: ease programmer burden by not requiring + * them to remember calls to super in overridden methods. Note that this will + * produce compiler errors if not overridden. + * + * @return an array of measurements to be used in evaluation tasks + */ + protected abstract Measurement[] getModelMeasurementsImpl(); + + /** + * Returns a string representation of the model. + * + * @param out + * the stringbuilder to add the description + * @param indent + * the number of characters to indent + */ + public abstract void getModelDescription(StringBuilder out, int indent); + + /** + * Gets the index of the attribute in the instance, given the index of the + * attribute in the learner. + * + * @param index + * the index of the attribute in the learner + * @return the index in the instance + */ + protected static int modelAttIndexToInstanceAttIndex(int index) { + return index; // inst.classIndex() > index ? index : index + 1; + } } http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/23a35dbe/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/Classifier.java ---------------------------------------------------------------------- diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/Classifier.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/Classifier.java index efbc918..bdda15a 100644 --- a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/Classifier.java +++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/Classifier.java @@ -26,52 +26,55 @@ import com.yahoo.labs.samoa.moa.learners.Learner; /** * Classifier interface for incremental classification models. - * + * * @author Richard Kirkby ([email protected]) * @version $Revision: 7 $ */ public interface Classifier extends Learner<Example<Instance>> { - /** - * Gets the classifiers of this ensemble. Returns null if this learner is a - * single learner. - * - * @return an array of the learners of the ensemble - */ - public Classifier[] getSubClassifiers(); + /** + * Gets the classifiers of this ensemble. Returns null if this learner is a + * single learner. + * + * @return an array of the learners of the ensemble + */ + public Classifier[] getSubClassifiers(); - /** - * Produces a copy of this learner. - * - * @return the copy of this learner - */ - public Classifier copy(); + /** + * Produces a copy of this learner. + * + * @return the copy of this learner + */ + public Classifier copy(); - /** - * Gets whether this classifier correctly classifies an instance. Uses - * getVotesForInstance to obtain the prediction and the instance to obtain - * its true class. - * - * - * @param inst the instance to be classified - * @return true if the instance is correctly classified - */ - public boolean correctlyClassifies(Instance inst); + /** + * Gets whether this classifier correctly classifies an instance. Uses + * getVotesForInstance to obtain the prediction and the instance to obtain its + * true class. + * + * + * @param inst + * the instance to be classified + * @return true if the instance is correctly classified + */ + public boolean correctlyClassifies(Instance inst); - /** - * Trains this learner incrementally using the given example. - * - * @param inst the instance to be used for training - */ - public void trainOnInstance(Instance inst); + /** + * Trains this learner incrementally using the given example. + * + * @param inst + * the instance to be used for training + */ + public void trainOnInstance(Instance inst); - /** - * Predicts the class memberships for a given instance. If an instance is - * unclassified, the returned array elements must be all zero. - * - * @param inst the instance to be classified - * @return an array containing the estimated membership probabilities of the - * test instance in each class - */ - public double[] getVotesForInstance(Instance inst); + /** + * Predicts the class memberships for a given instance. If an instance is + * unclassified, the returned array elements must be all zero. + * + * @param inst + * the instance to be classified + * @return an array containing the estimated membership probabilities of the + * test instance in each class + */ + public double[] getVotesForInstance(Instance inst); } http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/23a35dbe/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/Regressor.java ---------------------------------------------------------------------- diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/Regressor.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/Regressor.java index 758f5c4..53d86a2 100644 --- a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/Regressor.java +++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/Regressor.java @@ -21,11 +21,12 @@ package com.yahoo.labs.samoa.moa.classifiers; */ /** - * Regressor interface for incremental regression models. It is used only in the GUI Regression Tab. - * + * Regressor interface for incremental regression models. It is used only in the + * GUI Regression Tab. + * * @author Richard Kirkby ([email protected]) * @version $Revision: 7 $ */ public interface Regressor { - + } http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/23a35dbe/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/AttributeSplitSuggestion.java ---------------------------------------------------------------------- diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/AttributeSplitSuggestion.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/AttributeSplitSuggestion.java index 1ecc9ed..e469c72 100644 --- a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/AttributeSplitSuggestion.java +++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/AttributeSplitSuggestion.java @@ -25,44 +25,45 @@ import com.yahoo.labs.samoa.moa.classifiers.core.conditionaltests.InstanceCondit /** * Class for computing attribute split suggestions given a split test. - * + * * @author Richard Kirkby ([email protected]) * @version $Revision: 7 $ */ public class AttributeSplitSuggestion extends AbstractMOAObject implements Comparable<AttributeSplitSuggestion> { - - private static final long serialVersionUID = 1L; - public InstanceConditionalTest splitTest; + private static final long serialVersionUID = 1L; + + public InstanceConditionalTest splitTest; + + public double[][] resultingClassDistributions; - public double[][] resultingClassDistributions; + public double merit; - public double merit; - - public AttributeSplitSuggestion() {} + public AttributeSplitSuggestion() { + } - public AttributeSplitSuggestion(InstanceConditionalTest splitTest, - double[][] resultingClassDistributions, double merit) { - this.splitTest = splitTest; - this.resultingClassDistributions = resultingClassDistributions.clone(); - this.merit = merit; - } + public AttributeSplitSuggestion(InstanceConditionalTest splitTest, + double[][] resultingClassDistributions, double merit) { + this.splitTest = splitTest; + this.resultingClassDistributions = resultingClassDistributions.clone(); + this.merit = merit; + } - public int numSplits() { - return this.resultingClassDistributions.length; - } + public int numSplits() { + return this.resultingClassDistributions.length; + } - public double[] resultingClassDistributionFromSplit(int splitIndex) { - return this.resultingClassDistributions[splitIndex].clone(); - } + public double[] resultingClassDistributionFromSplit(int splitIndex) { + return this.resultingClassDistributions[splitIndex].clone(); + } - @Override - public int compareTo(AttributeSplitSuggestion comp) { - return Double.compare(this.merit, comp.merit); - } + @Override + public int compareTo(AttributeSplitSuggestion comp) { + return Double.compare(this.merit, comp.merit); + } - @Override - public void getDescription(StringBuilder sb, int indent) { - // do nothing - } + @Override + public void getDescription(StringBuilder sb, int indent) { + // do nothing + } } http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/23a35dbe/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/attributeclassobservers/AttributeClassObserver.java ---------------------------------------------------------------------- diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/attributeclassobservers/AttributeClassObserver.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/attributeclassobservers/AttributeClassObserver.java index d6adc2e..a6cdf80 100644 --- a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/attributeclassobservers/AttributeClassObserver.java +++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/attributeclassobservers/AttributeClassObserver.java @@ -25,49 +25,57 @@ import com.yahoo.labs.samoa.moa.classifiers.core.splitcriteria.SplitCriterion; import com.yahoo.labs.samoa.moa.options.OptionHandler; /** - * Interface for observing the class data distribution for an attribute. - * This observer monitors the class distribution of a given attribute. - * Used in naive Bayes and decision trees to monitor data statistics on leaves. - * + * Interface for observing the class data distribution for an attribute. This + * observer monitors the class distribution of a given attribute. Used in naive + * Bayes and decision trees to monitor data statistics on leaves. + * * @author Richard Kirkby ([email protected]) - * @version $Revision: 7 $ + * @version $Revision: 7 $ */ public interface AttributeClassObserver extends OptionHandler { - /** - * Updates statistics of this observer given an attribute value, a class - * and the weight of the instance observed - * - * @param attVal the value of the attribute - * @param classVal the class - * @param weight the weight of the instance - */ - public void observeAttributeClass(double attVal, int classVal, double weight); + /** + * Updates statistics of this observer given an attribute value, a class and + * the weight of the instance observed + * + * @param attVal + * the value of the attribute + * @param classVal + * the class + * @param weight + * the weight of the instance + */ + public void observeAttributeClass(double attVal, int classVal, double weight); - /** - * Gets the probability for an attribute value given a class - * - * @param attVal the attribute value - * @param classVal the class - * @return probability for an attribute value given a class - */ - public double probabilityOfAttributeValueGivenClass(double attVal, - int classVal); + /** + * Gets the probability for an attribute value given a class + * + * @param attVal + * the attribute value + * @param classVal + * the class + * @return probability for an attribute value given a class + */ + public double probabilityOfAttributeValueGivenClass(double attVal, + int classVal); - /** - * Gets the best split suggestion given a criterion and a class distribution - * - * @param criterion the split criterion to use - * @param preSplitDist the class distribution before the split - * @param attIndex the attribute index - * @param binaryOnly true to use binary splits - * @return suggestion of best attribute split - */ - public AttributeSplitSuggestion getBestEvaluatedSplitSuggestion( - SplitCriterion criterion, double[] preSplitDist, int attIndex, - boolean binaryOnly); + /** + * Gets the best split suggestion given a criterion and a class distribution + * + * @param criterion + * the split criterion to use + * @param preSplitDist + * the class distribution before the split + * @param attIndex + * the attribute index + * @param binaryOnly + * true to use binary splits + * @return suggestion of best attribute split + */ + public AttributeSplitSuggestion getBestEvaluatedSplitSuggestion( + SplitCriterion criterion, double[] preSplitDist, int attIndex, + boolean binaryOnly); + public void observeAttributeTarget(double attVal, double target); - public void observeAttributeTarget(double attVal, double target); - } http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/23a35dbe/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/attributeclassobservers/BinaryTreeNumericAttributeClassObserver.java ---------------------------------------------------------------------- diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/attributeclassobservers/BinaryTreeNumericAttributeClassObserver.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/attributeclassobservers/BinaryTreeNumericAttributeClassObserver.java index e9bb2f9..2b45209 100644 --- a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/attributeclassobservers/BinaryTreeNumericAttributeClassObserver.java +++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/attributeclassobservers/BinaryTreeNumericAttributeClassObserver.java @@ -30,154 +30,155 @@ import com.yahoo.labs.samoa.moa.options.AbstractOptionHandler; import com.yahoo.labs.samoa.moa.tasks.TaskMonitor; /** - * Class for observing the class data distribution for a numeric attribute using a binary tree. - * This observer monitors the class distribution of a given attribute. - * Used in naive Bayes and decision trees to monitor data statistics on leaves. - * + * Class for observing the class data distribution for a numeric attribute using + * a binary tree. This observer monitors the class distribution of a given + * attribute. Used in naive Bayes and decision trees to monitor data statistics + * on leaves. + * * @author Richard Kirkby ([email protected]) * @version $Revision: 7 $ */ public class BinaryTreeNumericAttributeClassObserver extends AbstractOptionHandler - implements NumericAttributeClassObserver { - - private static final long serialVersionUID = 1L; + implements NumericAttributeClassObserver { - public class Node implements Serializable { + private static final long serialVersionUID = 1L; - private static final long serialVersionUID = 1L; + public class Node implements Serializable { - public double cut_point; + private static final long serialVersionUID = 1L; - public DoubleVector classCountsLeft = new DoubleVector(); + public double cut_point; - public DoubleVector classCountsRight = new DoubleVector(); + public DoubleVector classCountsLeft = new DoubleVector(); - public Node left; + public DoubleVector classCountsRight = new DoubleVector(); - public Node right; + public Node left; - public Node(double val, int label, double weight) { - this.cut_point = val; - this.classCountsLeft.addToValue(label, weight); - } + public Node right; - public void insertValue(double val, int label, double weight) { - if (val == this.cut_point) { - this.classCountsLeft.addToValue(label, weight); - } else if (val <= this.cut_point) { - this.classCountsLeft.addToValue(label, weight); - if (this.left == null) { - this.left = new Node(val, label, weight); - } else { - this.left.insertValue(val, label, weight); - } - } else { // val > cut_point - this.classCountsRight.addToValue(label, weight); - if (this.right == null) { - this.right = new Node(val, label, weight); - } else { - this.right.insertValue(val, label, weight); - } - } - } + public Node(double val, int label, double weight) { + this.cut_point = val; + this.classCountsLeft.addToValue(label, weight); } - public Node root = null; - - @Override - public void observeAttributeClass(double attVal, int classVal, double weight) { - if (Double.isNaN(attVal)) { //Instance.isMissingValue(attVal) + public void insertValue(double val, int label, double weight) { + if (val == this.cut_point) { + this.classCountsLeft.addToValue(label, weight); + } else if (val <= this.cut_point) { + this.classCountsLeft.addToValue(label, weight); + if (this.left == null) { + this.left = new Node(val, label, weight); } else { - if (this.root == null) { - this.root = new Node(attVal, classVal, weight); - } else { - this.root.insertValue(attVal, classVal, weight); - } + this.left.insertValue(val, label, weight); } + } else { // val > cut_point + this.classCountsRight.addToValue(label, weight); + if (this.right == null) { + this.right = new Node(val, label, weight); + } else { + this.right.insertValue(val, label, weight); + } + } } - - @Override - public double probabilityOfAttributeValueGivenClass(double attVal, - int classVal) { - // TODO: NaiveBayes broken until implemented - return 0.0; + } + + public Node root = null; + + @Override + public void observeAttributeClass(double attVal, int classVal, double weight) { + if (Double.isNaN(attVal)) { // Instance.isMissingValue(attVal) + } else { + if (this.root == null) { + this.root = new Node(attVal, classVal, weight); + } else { + this.root.insertValue(attVal, classVal, weight); + } } - - @Override - public AttributeSplitSuggestion getBestEvaluatedSplitSuggestion( - SplitCriterion criterion, double[] preSplitDist, int attIndex, - boolean binaryOnly) { - return searchForBestSplitOption(this.root, null, null, null, null, false, - criterion, preSplitDist, attIndex); + } + + @Override + public double probabilityOfAttributeValueGivenClass(double attVal, + int classVal) { + // TODO: NaiveBayes broken until implemented + return 0.0; + } + + @Override + public AttributeSplitSuggestion getBestEvaluatedSplitSuggestion( + SplitCriterion criterion, double[] preSplitDist, int attIndex, + boolean binaryOnly) { + return searchForBestSplitOption(this.root, null, null, null, null, false, + criterion, preSplitDist, attIndex); + } + + protected AttributeSplitSuggestion searchForBestSplitOption( + Node currentNode, AttributeSplitSuggestion currentBestOption, + double[] actualParentLeft, + double[] parentLeft, double[] parentRight, boolean leftChild, + SplitCriterion criterion, double[] preSplitDist, int attIndex) { + if (currentNode == null) { + return currentBestOption; } - - protected AttributeSplitSuggestion searchForBestSplitOption( - Node currentNode, AttributeSplitSuggestion currentBestOption, - double[] actualParentLeft, - double[] parentLeft, double[] parentRight, boolean leftChild, - SplitCriterion criterion, double[] preSplitDist, int attIndex) { - if (currentNode == null) { - return currentBestOption; - } - DoubleVector leftDist = new DoubleVector(); - DoubleVector rightDist = new DoubleVector(); - if (parentLeft == null) { - leftDist.addValues(currentNode.classCountsLeft); - rightDist.addValues(currentNode.classCountsRight); - } else { - leftDist.addValues(parentLeft); - rightDist.addValues(parentRight); - if (leftChild) { - //get the exact statistics of the parent value - DoubleVector exactParentDist = new DoubleVector(); - exactParentDist.addValues(actualParentLeft); - exactParentDist.subtractValues(currentNode.classCountsLeft); - exactParentDist.subtractValues(currentNode.classCountsRight); - - // move the subtrees - leftDist.subtractValues(currentNode.classCountsRight); - rightDist.addValues(currentNode.classCountsRight); - - // move the exact value from the parent - rightDist.addValues(exactParentDist); - leftDist.subtractValues(exactParentDist); - - } else { - leftDist.addValues(currentNode.classCountsLeft); - rightDist.subtractValues(currentNode.classCountsLeft); - } - } - double[][] postSplitDists = new double[][]{leftDist.getArrayRef(), - rightDist.getArrayRef()}; - double merit = criterion.getMeritOfSplit(preSplitDist, postSplitDists); - if ((currentBestOption == null) || (merit > currentBestOption.merit)) { - currentBestOption = new AttributeSplitSuggestion( - new NumericAttributeBinaryTest(attIndex, - currentNode.cut_point, true), postSplitDists, merit); - - } - currentBestOption = searchForBestSplitOption(currentNode.left, - currentBestOption, currentNode.classCountsLeft.getArrayRef(), postSplitDists[0], postSplitDists[1], true, - criterion, preSplitDist, attIndex); - currentBestOption = searchForBestSplitOption(currentNode.right, - currentBestOption, currentNode.classCountsLeft.getArrayRef(), postSplitDists[0], postSplitDists[1], false, - criterion, preSplitDist, attIndex); - return currentBestOption; + DoubleVector leftDist = new DoubleVector(); + DoubleVector rightDist = new DoubleVector(); + if (parentLeft == null) { + leftDist.addValues(currentNode.classCountsLeft); + rightDist.addValues(currentNode.classCountsRight); + } else { + leftDist.addValues(parentLeft); + rightDist.addValues(parentRight); + if (leftChild) { + // get the exact statistics of the parent value + DoubleVector exactParentDist = new DoubleVector(); + exactParentDist.addValues(actualParentLeft); + exactParentDist.subtractValues(currentNode.classCountsLeft); + exactParentDist.subtractValues(currentNode.classCountsRight); + + // move the subtrees + leftDist.subtractValues(currentNode.classCountsRight); + rightDist.addValues(currentNode.classCountsRight); + + // move the exact value from the parent + rightDist.addValues(exactParentDist); + leftDist.subtractValues(exactParentDist); + + } else { + leftDist.addValues(currentNode.classCountsLeft); + rightDist.subtractValues(currentNode.classCountsLeft); + } } + double[][] postSplitDists = new double[][] { leftDist.getArrayRef(), + rightDist.getArrayRef() }; + double merit = criterion.getMeritOfSplit(preSplitDist, postSplitDists); + if ((currentBestOption == null) || (merit > currentBestOption.merit)) { + currentBestOption = new AttributeSplitSuggestion( + new NumericAttributeBinaryTest(attIndex, + currentNode.cut_point, true), postSplitDists, merit); - @Override - public void getDescription(StringBuilder sb, int indent) { - // TODO Auto-generated method stub } + currentBestOption = searchForBestSplitOption(currentNode.left, + currentBestOption, currentNode.classCountsLeft.getArrayRef(), postSplitDists[0], postSplitDists[1], true, + criterion, preSplitDist, attIndex); + currentBestOption = searchForBestSplitOption(currentNode.right, + currentBestOption, currentNode.classCountsLeft.getArrayRef(), postSplitDists[0], postSplitDists[1], false, + criterion, preSplitDist, attIndex); + return currentBestOption; + } + + @Override + public void getDescription(StringBuilder sb, int indent) { + // TODO Auto-generated method stub + } + + @Override + protected void prepareForUseImpl(TaskMonitor monitor, ObjectRepository repository) { + // TODO Auto-generated method stub + } + + @Override + public void observeAttributeTarget(double attVal, double target) { + throw new UnsupportedOperationException("Not supported yet."); + } - @Override - protected void prepareForUseImpl(TaskMonitor monitor, ObjectRepository repository) { - // TODO Auto-generated method stub - } - - @Override - public void observeAttributeTarget(double attVal, double target) { - throw new UnsupportedOperationException("Not supported yet."); - } - } http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/23a35dbe/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/attributeclassobservers/BinaryTreeNumericAttributeClassObserverRegression.java ---------------------------------------------------------------------- diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/attributeclassobservers/BinaryTreeNumericAttributeClassObserverRegression.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/attributeclassobservers/BinaryTreeNumericAttributeClassObserverRegression.java index a68cad9..eeffebd 100644 --- a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/attributeclassobservers/BinaryTreeNumericAttributeClassObserverRegression.java +++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/attributeclassobservers/BinaryTreeNumericAttributeClassObserverRegression.java @@ -1,4 +1,3 @@ - package com.yahoo.labs.samoa.moa.classifiers.core.attributeclassobservers; /* @@ -21,7 +20,6 @@ package com.yahoo.labs.samoa.moa.classifiers.core.attributeclassobservers; * #L% */ - import java.io.Serializable; import com.yahoo.labs.samoa.moa.classifiers.core.AttributeSplitSuggestion; import com.yahoo.labs.samoa.moa.classifiers.core.splitcriteria.SplitCriterion; @@ -30,119 +28,128 @@ import com.yahoo.labs.samoa.moa.options.AbstractOptionHandler; import com.yahoo.labs.samoa.moa.tasks.TaskMonitor; /** - * Class for observing the class data distribution for a numeric attribute using a binary tree. - * This observer monitors the class distribution of a given attribute. - * - * <p>Learning Adaptive Model Rules from High-Speed Data Streams, ECML 2013, E. Almeida, C. Ferreira, P. Kosina and J. Gama; </p> - * + * Class for observing the class data distribution for a numeric attribute using + * a binary tree. This observer monitors the class distribution of a given + * attribute. + * + * <p> + * Learning Adaptive Model Rules from High-Speed Data Streams, ECML 2013, E. + * Almeida, C. Ferreira, P. Kosina and J. Gama; + * </p> + * * @author E. Almeida, J. Gama * @version $Revision: 2$ */ public class BinaryTreeNumericAttributeClassObserverRegression extends AbstractOptionHandler - implements NumericAttributeClassObserver { - - public static final long serialVersionUID = 1L; - - public class Node implements Serializable { - - private static final long serialVersionUID = 1L; - - public double cut_point; - - public double[] lessThan; //This array maintains statistics for the instance reaching the node with attribute values less than or iqual to the cutpoint. - - public double[] greaterThan; //This array maintains statistics for the instance reaching the node with attribute values greater than to the cutpoint. - - public Node left; - - public Node right; - - public Node(double val, double target) { - this.cut_point = val; - this.lessThan = new double[3]; - this.greaterThan = new double[3]; - this.lessThan[0] = target; //The sum of their target attribute values. - this.lessThan[1] = target * target; //The sum of the squared target attribute values. - this.lessThan[2] = 1.0; //A counter of the number of instances that have reached the node. - this.greaterThan[0] = 0.0; - this.greaterThan[1] = 0.0; - this.greaterThan[2] = 0.0; - } + implements NumericAttributeClassObserver { - public void insertValue(double val, double target) { - if (val == this.cut_point) { - this.lessThan[0] = this.lessThan[0] + target; - this.lessThan[1] = this.lessThan[1] + (target * target); - this.lessThan[2] = this.lessThan[2] + 1; - } else if (val <= this.cut_point) { - this.lessThan[0] = this.lessThan[0] + target; - this.lessThan[1] = this.lessThan[1] + (target * target); - this.lessThan[2] = this.lessThan[2] + 1; - if (this.left == null) { - this.left = new Node(val, target); - } else { - this.left.insertValue(val, target); - } - } else { - this.greaterThan[0] = this.greaterThan[0] + target; - this.greaterThan[1] = this.greaterThan[1] + (target*target); - this.greaterThan[2] = this.greaterThan[2] + 1; - if (this.right == null) { - - this.right = new Node(val, target); - } else { - this.right.insertValue(val, target); - } - } - } - } + public static final long serialVersionUID = 1L; - public Node root1 = null; - - public void observeAttributeTarget(double attVal, double target){ - if (!Double.isNaN(attVal)) { - if (this.root1 == null) { - this.root1 = new Node(attVal, target); - } else { - this.root1.insertValue(attVal, target); - } - } - } + public class Node implements Serializable { - @Override - public void observeAttributeClass(double attVal, int classVal, double weight) { - - } + private static final long serialVersionUID = 1L; - @Override - public double probabilityOfAttributeValueGivenClass(double attVal, - int classVal) { - return 0.0; - } + public double cut_point; - @Override - public AttributeSplitSuggestion getBestEvaluatedSplitSuggestion( - SplitCriterion criterion, double[] preSplitDist, int attIndex, - boolean binaryOnly) { - return searchForBestSplitOption(this.root1, null, null, null, null, false, - criterion, preSplitDist, attIndex); - } + public double[] lessThan; // This array maintains statistics for the + // instance reaching the node with attribute + // values less than or iqual to the cutpoint. + + public double[] greaterThan; // This array maintains statistics for the + // instance reaching the node with attribute + // values greater than to the cutpoint. + + public Node left; - protected AttributeSplitSuggestion searchForBestSplitOption( - Node currentNode, AttributeSplitSuggestion currentBestOption, - double[] actualParentLeft, - double[] parentLeft, double[] parentRight, boolean leftChild, - SplitCriterion criterion, double[] preSplitDist, int attIndex) { - - return currentBestOption; + public Node right; + + public Node(double val, double target) { + this.cut_point = val; + this.lessThan = new double[3]; + this.greaterThan = new double[3]; + this.lessThan[0] = target; // The sum of their target attribute values. + this.lessThan[1] = target * target; // The sum of the squared target + // attribute values. + this.lessThan[2] = 1.0; // A counter of the number of instances that have + // reached the node. + this.greaterThan[0] = 0.0; + this.greaterThan[1] = 0.0; + this.greaterThan[2] = 0.0; } - @Override - public void getDescription(StringBuilder sb, int indent) { + public void insertValue(double val, double target) { + if (val == this.cut_point) { + this.lessThan[0] = this.lessThan[0] + target; + this.lessThan[1] = this.lessThan[1] + (target * target); + this.lessThan[2] = this.lessThan[2] + 1; + } else if (val <= this.cut_point) { + this.lessThan[0] = this.lessThan[0] + target; + this.lessThan[1] = this.lessThan[1] + (target * target); + this.lessThan[2] = this.lessThan[2] + 1; + if (this.left == null) { + this.left = new Node(val, target); + } else { + this.left.insertValue(val, target); + } + } else { + this.greaterThan[0] = this.greaterThan[0] + target; + this.greaterThan[1] = this.greaterThan[1] + (target * target); + this.greaterThan[2] = this.greaterThan[2] + 1; + if (this.right == null) { + + this.right = new Node(val, target); + } else { + this.right.insertValue(val, target); + } + } } + } + + public Node root1 = null; - @Override - protected void prepareForUseImpl(TaskMonitor monitor, ObjectRepository repository) { + public void observeAttributeTarget(double attVal, double target) { + if (!Double.isNaN(attVal)) { + if (this.root1 == null) { + this.root1 = new Node(attVal, target); + } else { + this.root1.insertValue(attVal, target); + } } + } + + @Override + public void observeAttributeClass(double attVal, int classVal, double weight) { + + } + + @Override + public double probabilityOfAttributeValueGivenClass(double attVal, + int classVal) { + return 0.0; + } + + @Override + public AttributeSplitSuggestion getBestEvaluatedSplitSuggestion( + SplitCriterion criterion, double[] preSplitDist, int attIndex, + boolean binaryOnly) { + return searchForBestSplitOption(this.root1, null, null, null, null, false, + criterion, preSplitDist, attIndex); + } + + protected AttributeSplitSuggestion searchForBestSplitOption( + Node currentNode, AttributeSplitSuggestion currentBestOption, + double[] actualParentLeft, + double[] parentLeft, double[] parentRight, boolean leftChild, + SplitCriterion criterion, double[] preSplitDist, int attIndex) { + + return currentBestOption; + } + + @Override + public void getDescription(StringBuilder sb, int indent) { + } + + @Override + protected void prepareForUseImpl(TaskMonitor monitor, ObjectRepository repository) { + } } - http://git-wip-us.apache.org/repos/asf/incubator-samoa/blob/23a35dbe/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/attributeclassobservers/DiscreteAttributeClassObserver.java ---------------------------------------------------------------------- diff --git a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/attributeclassobservers/DiscreteAttributeClassObserver.java b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/attributeclassobservers/DiscreteAttributeClassObserver.java index e756fcd..fe16447 100644 --- a/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/attributeclassobservers/DiscreteAttributeClassObserver.java +++ b/samoa-api/src/main/java/com/yahoo/labs/samoa/moa/classifiers/core/attributeclassobservers/DiscreteAttributeClassObserver.java @@ -21,14 +21,14 @@ package com.yahoo.labs.samoa.moa.classifiers.core.attributeclassobservers; */ /** - * Interface for observing the class data distribution for a discrete (nominal) attribute. - * This observer monitors the class distribution of a given attribute. - * Used in naive Bayes and decision trees to monitor data statistics on leaves. - * + * Interface for observing the class data distribution for a discrete (nominal) + * attribute. This observer monitors the class distribution of a given + * attribute. Used in naive Bayes and decision trees to monitor data statistics + * on leaves. + * * @author Richard Kirkby ([email protected]) - * @version $Revision: 7 $ + * @version $Revision: 7 $ */ public interface DiscreteAttributeClassObserver extends AttributeClassObserver { - }
