http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/builder/DecisionTreeBuilder.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/builder/DecisionTreeBuilder.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/builder/DecisionTreeBuilder.java new file mode 100644 index 0000000..9f84e9c --- /dev/null +++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/builder/DecisionTreeBuilder.java @@ -0,0 +1,422 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.classifier.df.builder; + +import org.apache.mahout.classifier.df.data.Data; +import org.apache.mahout.classifier.df.data.Dataset; +import org.apache.mahout.classifier.df.data.Instance; +import org.apache.mahout.classifier.df.data.conditions.Condition; +import org.apache.mahout.classifier.df.node.CategoricalNode; +import org.apache.mahout.classifier.df.node.Leaf; +import org.apache.mahout.classifier.df.node.Node; +import org.apache.mahout.classifier.df.node.NumericalNode; +import org.apache.mahout.classifier.df.split.IgSplit; +import org.apache.mahout.classifier.df.split.OptIgSplit; +import org.apache.mahout.classifier.df.split.RegressionSplit; +import org.apache.mahout.classifier.df.split.Split; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.Collection; +import java.util.HashSet; +import java.util.Random; + +/** + * Builds a classification tree or regression tree<br> + * A classification tree is built when the criterion variable is the categorical attribute.<br> + * A regression tree is built when the criterion variable is the numerical attribute. 
+ */ +@Deprecated +public class DecisionTreeBuilder implements TreeBuilder { + + private static final Logger log = LoggerFactory.getLogger(DecisionTreeBuilder.class); + + private static final int[] NO_ATTRIBUTES = new int[0]; + private static final double EPSILON = 1.0e-6; + + /** + * indicates which CATEGORICAL attributes have already been selected in the parent nodes + */ + private boolean[] selected; + /** + * number of attributes to select randomly at each node + */ + private int m; + /** + * IgSplit implementation + */ + private IgSplit igSplit; + /** + * tree is complemented + */ + private boolean complemented = true; + /** + * minimum number for split + */ + private double minSplitNum = 2.0; + /** + * minimum proportion of the total variance for split + */ + private double minVarianceProportion = 1.0e-3; + /** + * full set data + */ + private Data fullSet; + /** + * minimum variance for split + */ + private double minVariance = Double.NaN; + + public void setM(int m) { + this.m = m; + } + + public void setIgSplit(IgSplit igSplit) { + this.igSplit = igSplit; + } + + public void setComplemented(boolean complemented) { + this.complemented = complemented; + } + + public void setMinSplitNum(int minSplitNum) { + this.minSplitNum = minSplitNum; + } + + public void setMinVarianceProportion(double minVarianceProportion) { + this.minVarianceProportion = minVarianceProportion; + } + + @Override + public Node build(Random rng, Data data) { + if (selected == null) { + selected = new boolean[data.getDataset().nbAttributes()]; + selected[data.getDataset().getLabelId()] = true; // never select the label + } + if (m == 0) { + // set default m + double e = data.getDataset().nbAttributes() - 1; + if (data.getDataset().isNumerical(data.getDataset().getLabelId())) { + // regression + m = (int) Math.ceil(e / 3.0); + } else { + // classification + m = (int) Math.ceil(Math.sqrt(e)); + } + } + + if (data.isEmpty()) { + return new Leaf(Double.NaN); + } + + double sum = 0.0; + if 
(data.getDataset().isNumerical(data.getDataset().getLabelId())) { + // regression + // sum and sum squared of a label is computed + double sumSquared = 0.0; + for (int i = 0; i < data.size(); i++) { + double label = data.getDataset().getLabel(data.get(i)); + sum += label; + sumSquared += label * label; + } + + // computes the variance + double var = sumSquared - (sum * sum) / data.size(); + + // computes the minimum variance + if (Double.compare(minVariance, Double.NaN) == 0) { + minVariance = var / data.size() * minVarianceProportion; + log.debug("minVariance:{}", minVariance); + } + + // variance is compared with minimum variance + if ((var / data.size()) < minVariance) { + log.debug("variance({}) < minVariance({}) Leaf({})", var / data.size(), minVariance, sum / data.size()); + return new Leaf(sum / data.size()); + } + } else { + // classification + if (isIdentical(data)) { + return new Leaf(data.majorityLabel(rng)); + } + if (data.identicalLabel()) { + return new Leaf(data.getDataset().getLabel(data.get(0))); + } + } + + // store full set data + if (fullSet == null) { + fullSet = data; + } + + int[] attributes = randomAttributes(rng, selected, m); + if (attributes == null || attributes.length == 0) { + // we tried all the attributes and could not split the data anymore + double label; + if (data.getDataset().isNumerical(data.getDataset().getLabelId())) { + // regression + label = sum / data.size(); + } else { + // classification + label = data.majorityLabel(rng); + } + log.warn("attribute which can be selected is not found Leaf({})", label); + return new Leaf(label); + } + + if (igSplit == null) { + if (data.getDataset().isNumerical(data.getDataset().getLabelId())) { + // regression + igSplit = new RegressionSplit(); + } else { + // classification + igSplit = new OptIgSplit(); + } + } + + // find the best split + Split best = null; + for (int attr : attributes) { + Split split = igSplit.computeSplit(data, attr); + if (best == null || best.getIg() < 
split.getIg()) { + best = split; + } + } + + // information gain is near to zero. + if (best.getIg() < EPSILON) { + double label; + if (data.getDataset().isNumerical(data.getDataset().getLabelId())) { + label = sum / data.size(); + } else { + label = data.majorityLabel(rng); + } + log.debug("ig is near to zero Leaf({})", label); + return new Leaf(label); + } + + log.debug("best split attr:{}, split:{}, ig:{}", best.getAttr(), best.getSplit(), best.getIg()); + + boolean alreadySelected = selected[best.getAttr()]; + if (alreadySelected) { + // attribute already selected + log.warn("attribute {} already selected in a parent node", best.getAttr()); + } + + Node childNode; + if (data.getDataset().isNumerical(best.getAttr())) { + boolean[] temp = null; + + Data loSubset = data.subset(Condition.lesser(best.getAttr(), best.getSplit())); + Data hiSubset = data.subset(Condition.greaterOrEquals(best.getAttr(), best.getSplit())); + + if (loSubset.isEmpty() || hiSubset.isEmpty()) { + // the selected attribute did not change the data, avoid using it in the child notes + selected[best.getAttr()] = true; + } else { + // the data changed, so we can unselect all previousely selected NUMERICAL attributes + temp = selected; + selected = cloneCategoricalAttributes(data.getDataset(), selected); + } + + // size of the subset is less than the minSpitNum + if (loSubset.size() < minSplitNum || hiSubset.size() < minSplitNum) { + // branch is not split + double label; + if (data.getDataset().isNumerical(data.getDataset().getLabelId())) { + label = sum / data.size(); + } else { + label = data.majorityLabel(rng); + } + log.debug("branch is not split Leaf({})", label); + return new Leaf(label); + } + + Node loChild = build(rng, loSubset); + Node hiChild = build(rng, hiSubset); + + // restore the selection state of the attributes + if (temp != null) { + selected = temp; + } else { + selected[best.getAttr()] = alreadySelected; + } + + childNode = new NumericalNode(best.getAttr(), best.getSplit(), 
loChild, hiChild); + } else { // CATEGORICAL attribute + double[] values = data.values(best.getAttr()); + + // tree is complemented + Collection<Double> subsetValues = null; + if (complemented) { + subsetValues = new HashSet<>(); + for (double value : values) { + subsetValues.add(value); + } + values = fullSet.values(best.getAttr()); + } + + int cnt = 0; + Data[] subsets = new Data[values.length]; + for (int index = 0; index < values.length; index++) { + if (complemented && !subsetValues.contains(values[index])) { + continue; + } + subsets[index] = data.subset(Condition.equals(best.getAttr(), values[index])); + if (subsets[index].size() >= minSplitNum) { + cnt++; + } + } + + // size of the subset is less than the minSpitNum + if (cnt < 2) { + // branch is not split + double label; + if (data.getDataset().isNumerical(data.getDataset().getLabelId())) { + label = sum / data.size(); + } else { + label = data.majorityLabel(rng); + } + log.debug("branch is not split Leaf({})", label); + return new Leaf(label); + } + + selected[best.getAttr()] = true; + + Node[] children = new Node[values.length]; + for (int index = 0; index < values.length; index++) { + if (complemented && (subsetValues == null || !subsetValues.contains(values[index]))) { + // tree is complemented + double label; + if (data.getDataset().isNumerical(data.getDataset().getLabelId())) { + label = sum / data.size(); + } else { + label = data.majorityLabel(rng); + } + log.debug("complemented Leaf({})", label); + children[index] = new Leaf(label); + continue; + } + children[index] = build(rng, subsets[index]); + } + + selected[best.getAttr()] = alreadySelected; + + childNode = new CategoricalNode(best.getAttr(), values, children); + } + + return childNode; + } + + /** + * checks if all the vectors have identical attribute values. Ignore selected attributes. 
+ * + * @return true is all the vectors are identical or the data is empty<br> + * false otherwise + */ + private boolean isIdentical(Data data) { + if (data.isEmpty()) { + return true; + } + + Instance instance = data.get(0); + for (int attr = 0; attr < selected.length; attr++) { + if (selected[attr]) { + continue; + } + + for (int index = 1; index < data.size(); index++) { + if (data.get(index).get(attr) != instance.get(attr)) { + return false; + } + } + } + + return true; + } + + /** + * Make a copy of the selection state of the attributes, unselect all numerical attributes + * + * @param selected selection state to clone + * @return cloned selection state + */ + private static boolean[] cloneCategoricalAttributes(Dataset dataset, boolean[] selected) { + boolean[] cloned = new boolean[selected.length]; + + for (int i = 0; i < selected.length; i++) { + cloned[i] = !dataset.isNumerical(i) && selected[i]; + } + cloned[dataset.getLabelId()] = true; + + return cloned; + } + + /** + * Randomly selects m attributes to consider for split, excludes IGNORED and LABEL attributes + * + * @param rng random-numbers generator + * @param selected attributes' state (selected or not) + * @param m number of attributes to choose + * @return list of selected attributes' indices, or null if all attributes have already been selected + */ + private static int[] randomAttributes(Random rng, boolean[] selected, int m) { + int nbNonSelected = 0; // number of non selected attributes + for (boolean sel : selected) { + if (!sel) { + nbNonSelected++; + } + } + + if (nbNonSelected == 0) { + log.warn("All attributes are selected !"); + return NO_ATTRIBUTES; + } + + int[] result; + if (nbNonSelected <= m) { + // return all non selected attributes + result = new int[nbNonSelected]; + int index = 0; + for (int attr = 0; attr < selected.length; attr++) { + if (!selected[attr]) { + result[index++] = attr; + } + } + } else { + result = new int[m]; + for (int index = 0; index < m; index++) { + // 
randomly choose a "non selected" attribute + int rind; + do { + rind = rng.nextInt(selected.length); + } while (selected[rind]); + + result[index] = rind; + selected[rind] = true; // temporarily set the chosen attribute to be selected + } + + // the chosen attributes are not yet selected + for (int attr : result) { + selected[attr] = false; + } + } + + return result; + } +}
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/builder/DefaultTreeBuilder.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/builder/DefaultTreeBuilder.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/builder/DefaultTreeBuilder.java new file mode 100644 index 0000000..3392fb1 --- /dev/null +++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/builder/DefaultTreeBuilder.java @@ -0,0 +1,253 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.classifier.df.builder; + +import org.apache.mahout.classifier.df.data.Data; +import org.apache.mahout.classifier.df.data.Dataset; +import org.apache.mahout.classifier.df.data.Instance; +import org.apache.mahout.classifier.df.data.conditions.Condition; +import org.apache.mahout.classifier.df.node.CategoricalNode; +import org.apache.mahout.classifier.df.node.Leaf; +import org.apache.mahout.classifier.df.node.Node; +import org.apache.mahout.classifier.df.node.NumericalNode; +import org.apache.mahout.classifier.df.split.IgSplit; +import org.apache.mahout.classifier.df.split.OptIgSplit; +import org.apache.mahout.classifier.df.split.Split; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.Random; + +/** + * Builds a Decision Tree <br> + * Based on the algorithm described in the "Decision Trees" tutorials by Andrew W. Moore, available at:<br> + * <br> + * http://www.cs.cmu.edu/~awm/tutorials + * <br><br> + * This class can be used when the criterion variable is the categorical attribute. 
+ */ +@Deprecated +public class DefaultTreeBuilder implements TreeBuilder { + + private static final Logger log = LoggerFactory.getLogger(DefaultTreeBuilder.class); + + private static final int[] NO_ATTRIBUTES = new int[0]; + + /** + * indicates which CATEGORICAL attributes have already been selected in the parent nodes + */ + private boolean[] selected; + /** + * number of attributes to select randomly at each node + */ + private int m = 1; + /** + * IgSplit implementation + */ + private final IgSplit igSplit; + + public DefaultTreeBuilder() { + igSplit = new OptIgSplit(); + } + + public void setM(int m) { + this.m = m; + } + + @Override + public Node build(Random rng, Data data) { + + if (selected == null) { + selected = new boolean[data.getDataset().nbAttributes()]; + selected[data.getDataset().getLabelId()] = true; // never select the label + } + + if (data.isEmpty()) { + return new Leaf(-1); + } + if (isIdentical(data)) { + return new Leaf(data.majorityLabel(rng)); + } + if (data.identicalLabel()) { + return new Leaf(data.getDataset().getLabel(data.get(0))); + } + + int[] attributes = randomAttributes(rng, selected, m); + if (attributes == null || attributes.length == 0) { + // we tried all the attributes and could not split the data anymore + return new Leaf(data.majorityLabel(rng)); + } + + // find the best split + Split best = null; + for (int attr : attributes) { + Split split = igSplit.computeSplit(data, attr); + if (best == null || best.getIg() < split.getIg()) { + best = split; + } + } + + boolean alreadySelected = selected[best.getAttr()]; + if (alreadySelected) { + // attribute already selected + log.warn("attribute {} already selected in a parent node", best.getAttr()); + } + + Node childNode; + if (data.getDataset().isNumerical(best.getAttr())) { + boolean[] temp = null; + + Data loSubset = data.subset(Condition.lesser(best.getAttr(), best.getSplit())); + Data hiSubset = data.subset(Condition.greaterOrEquals(best.getAttr(), best.getSplit())); + + if 
(loSubset.isEmpty() || hiSubset.isEmpty()) { + // the selected attribute did not change the data, avoid using it in the child notes + selected[best.getAttr()] = true; + } else { + // the data changed, so we can unselect all previousely selected NUMERICAL attributes + temp = selected; + selected = cloneCategoricalAttributes(data.getDataset(), selected); + } + + Node loChild = build(rng, loSubset); + Node hiChild = build(rng, hiSubset); + + // restore the selection state of the attributes + if (temp != null) { + selected = temp; + } else { + selected[best.getAttr()] = alreadySelected; + } + + childNode = new NumericalNode(best.getAttr(), best.getSplit(), loChild, hiChild); + } else { // CATEGORICAL attribute + selected[best.getAttr()] = true; + + double[] values = data.values(best.getAttr()); + Node[] children = new Node[values.length]; + + for (int index = 0; index < values.length; index++) { + Data subset = data.subset(Condition.equals(best.getAttr(), values[index])); + children[index] = build(rng, subset); + } + + selected[best.getAttr()] = alreadySelected; + + childNode = new CategoricalNode(best.getAttr(), values, children); + } + + return childNode; + } + + /** + * checks if all the vectors have identical attribute values. Ignore selected attributes. 
+ * + * @return true is all the vectors are identical or the data is empty<br> + * false otherwise + */ + private boolean isIdentical(Data data) { + if (data.isEmpty()) { + return true; + } + + Instance instance = data.get(0); + for (int attr = 0; attr < selected.length; attr++) { + if (selected[attr]) { + continue; + } + + for (int index = 1; index < data.size(); index++) { + if (data.get(index).get(attr) != instance.get(attr)) { + return false; + } + } + } + + return true; + } + + + /** + * Make a copy of the selection state of the attributes, unselect all numerical attributes + * + * @param selected selection state to clone + * @return cloned selection state + */ + private static boolean[] cloneCategoricalAttributes(Dataset dataset, boolean[] selected) { + boolean[] cloned = new boolean[selected.length]; + + for (int i = 0; i < selected.length; i++) { + cloned[i] = !dataset.isNumerical(i) && selected[i]; + } + + return cloned; + } + + /** + * Randomly selects m attributes to consider for split, excludes IGNORED and LABEL attributes + * + * @param rng random-numbers generator + * @param selected attributes' state (selected or not) + * @param m number of attributes to choose + * @return list of selected attributes' indices, or null if all attributes have already been selected + */ + protected static int[] randomAttributes(Random rng, boolean[] selected, int m) { + int nbNonSelected = 0; // number of non selected attributes + for (boolean sel : selected) { + if (!sel) { + nbNonSelected++; + } + } + + if (nbNonSelected == 0) { + log.warn("All attributes are selected !"); + return NO_ATTRIBUTES; + } + + int[] result; + if (nbNonSelected <= m) { + // return all non selected attributes + result = new int[nbNonSelected]; + int index = 0; + for (int attr = 0; attr < selected.length; attr++) { + if (!selected[attr]) { + result[index++] = attr; + } + } + } else { + result = new int[m]; + for (int index = 0; index < m; index++) { + // randomly choose a "non selected" 
attribute + int rind; + do { + rind = rng.nextInt(selected.length); + } while (selected[rind]); + + result[index] = rind; + selected[rind] = true; // temporarily set the chosen attribute to be selected + } + + // the chosen attributes are not yet selected + for (int attr : result) { + selected[attr] = false; + } + } + + return result; + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/builder/TreeBuilder.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/builder/TreeBuilder.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/builder/TreeBuilder.java new file mode 100644 index 0000000..bf686a4 --- /dev/null +++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/builder/TreeBuilder.java @@ -0,0 +1,42 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.classifier.df.builder; + +import org.apache.mahout.classifier.df.data.Data; +import org.apache.mahout.classifier.df.node.Node; + +import java.util.Random; + +/** + * Abstract base class for TreeBuilders + */ +@Deprecated +public interface TreeBuilder { + + /** + * Builds a Decision tree using the training data + * + * @param rng + * random-numbers generator + * @param data + * training data + * @return root Node + */ + Node build(Random rng, Data data); + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/data/Data.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/data/Data.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/data/Data.java new file mode 100644 index 0000000..77e5ed5 --- /dev/null +++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/data/Data.java @@ -0,0 +1,281 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.classifier.df.data; + +import org.apache.mahout.classifier.df.data.conditions.Condition; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.HashSet; +import java.util.List; +import java.util.Random; + +/** + * Holds a list of vectors and their corresponding Dataset. contains various operations that deals with the + * vectors (subset, count,...) + * + */ +@Deprecated +public class Data implements Cloneable { + + private final List<Instance> instances; + + private final Dataset dataset; + + public Data(Dataset dataset) { + this.dataset = dataset; + this.instances = new ArrayList<>(); + } + + public Data(Dataset dataset, List<Instance> instances) { + this.dataset = dataset; + this.instances = new ArrayList<>(instances); + } + + /** + * @return the number of elements + */ + public int size() { + return instances.size(); + } + + /** + * @return true if this data contains no element + */ + public boolean isEmpty() { + return instances.isEmpty(); + } + + /** + * @param v + * element whose presence in this list if to be searched + * @return true is this data contains the specified element. + */ + public boolean contains(Instance v) { + return instances.contains(v); + } + + /** + * Returns the element at the specified position + * + * @param index + * index of element to return + * @return the element at the specified position + * @throws IndexOutOfBoundsException + * if the index is out of range + */ + public Instance get(int index) { + return instances.get(index); + } + + /** + * @return the subset from this data that matches the given condition + */ + public Data subset(Condition condition) { + List<Instance> subset = new ArrayList<>(); + + for (Instance instance : instances) { + if (condition.isTrueFor(instance)) { + subset.add(instance); + } + } + + return new Data(dataset, subset); + } + + /** + * if data has N cases, sample N cases at random -but with replacement. 
+ */ + public Data bagging(Random rng) { + int datasize = size(); + List<Instance> bag = new ArrayList<>(datasize); + + for (int i = 0; i < datasize; i++) { + bag.add(instances.get(rng.nextInt(datasize))); + } + + return new Data(dataset, bag); + } + + /** + * if data has N cases, sample N cases at random -but with replacement. + * + * @param sampled + * indicating which instance has been sampled + * + * @return sampled data + */ + public Data bagging(Random rng, boolean[] sampled) { + int datasize = size(); + List<Instance> bag = new ArrayList<>(datasize); + + for (int i = 0; i < datasize; i++) { + int index = rng.nextInt(datasize); + bag.add(instances.get(index)); + sampled[index] = true; + } + + return new Data(dataset, bag); + } + + /** + * Splits the data in two, returns one part, and this gets the rest of the data. <b>VERY SLOW!</b> + */ + public Data rsplit(Random rng, int subsize) { + List<Instance> subset = new ArrayList<>(subsize); + + for (int i = 0; i < subsize; i++) { + subset.add(instances.remove(rng.nextInt(instances.size()))); + } + + return new Data(dataset, subset); + } + + /** + * checks if all the vectors have identical attribute values + * + * @return true is all the vectors are identical or the data is empty<br> + * false otherwise + */ + public boolean isIdentical() { + if (isEmpty()) { + return true; + } + + Instance instance = get(0); + for (int attr = 0; attr < dataset.nbAttributes(); attr++) { + for (int index = 1; index < size(); index++) { + if (get(index).get(attr) != instance.get(attr)) { + return false; + } + } + } + + return true; + } + + /** + * checks if all the vectors have identical label values + */ + public boolean identicalLabel() { + if (isEmpty()) { + return true; + } + + double label = dataset.getLabel(get(0)); + for (int index = 1; index < size(); index++) { + if (dataset.getLabel(get(index)) != label) { + return false; + } + } + + return true; + } + + /** + * finds all distinct values of a given attribute + */ + public 
double[] values(int attr) { + Collection<Double> result = new HashSet<>(); + + for (Instance instance : instances) { + result.add(instance.get(attr)); + } + + double[] values = new double[result.size()]; + + int index = 0; + for (Double value : result) { + values[index++] = value; + } + + return values; + } + + @Override + public Data clone() { + return new Data(dataset, new ArrayList<>(instances)); + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (!(obj instanceof Data)) { + return false; + } + + Data data = (Data) obj; + + return instances.equals(data.instances) && dataset.equals(data.dataset); + } + + @Override + public int hashCode() { + return instances.hashCode() + dataset.hashCode(); + } + + /** + * extract the labels of all instances + */ + public double[] extractLabels() { + double[] labels = new double[size()]; + + for (int index = 0; index < labels.length; index++) { + labels[index] = dataset.getLabel(get(index)); + } + + return labels; + } + + /** + * finds the majority label, breaking ties randomly<br> + * This method can be used when the criterion variable is the categorical attribute. + * + * @return the majority label value + */ + public int majorityLabel(Random rng) { + // count the frequency of each label value + int[] counts = new int[dataset.nblabels()]; + + for (int index = 0; index < size(); index++) { + counts[(int) dataset.getLabel(get(index))]++; + } + + // find the label values that appears the most + return DataUtils.maxindex(rng, counts); + } + + /** + * Counts the number of occurrences of each label value<br> + * This method can be used when the criterion variable is the categorical attribute. 
+ * + * @param counts + * will contain the results, supposed to be initialized at 0 + */ + public void countLabels(int[] counts) { + for (int index = 0; index < size(); index++) { + counts[(int) dataset.getLabel(get(index))]++; + } + } + + public Dataset getDataset() { + return dataset; + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/data/DataConverter.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/data/DataConverter.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/data/DataConverter.java new file mode 100644 index 0000000..f1bdc95 --- /dev/null +++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/data/DataConverter.java @@ -0,0 +1,72 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.classifier.df.data; + +import com.google.common.base.Preconditions; +import org.apache.commons.lang3.ArrayUtils; +import org.apache.mahout.math.DenseVector; + +import java.util.regex.Pattern; + +/** + * Converts String to Instance using a Dataset + */ +@Deprecated +public class DataConverter { + + private static final Pattern COMMA_SPACE = Pattern.compile("[, ]"); + + private final Dataset dataset; + + public DataConverter(Dataset dataset) { + this.dataset = dataset; + } + + public Instance convert(CharSequence string) { + // all attributes (categorical, numerical, label), ignored + int nball = dataset.nbAttributes() + dataset.getIgnored().length; + + String[] tokens = COMMA_SPACE.split(string); + Preconditions.checkArgument(tokens.length == nball, + "Wrong number of attributes in the string: " + tokens.length + ". Must be " + nball); + + int nbattrs = dataset.nbAttributes(); + DenseVector vector = new DenseVector(nbattrs); + + int aId = 0; + for (int attr = 0; attr < nball; attr++) { + if (!ArrayUtils.contains(dataset.getIgnored(), attr)) { + String token = tokens[attr].trim(); + + if ("?".equals(token)) { + // missing value + return null; + } + + if (dataset.isNumerical(aId)) { + vector.set(aId++, Double.parseDouble(token)); + } else { // CATEGORICAL + vector.set(aId, dataset.valueOf(aId, token)); + aId++; + } + } + } + + return new Instance(vector); + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/data/DataLoader.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/data/DataLoader.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/data/DataLoader.java new file mode 100644 index 0000000..c62dcac --- /dev/null +++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/data/DataLoader.java @@ -0,0 
+1,255 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.classifier.df.data; + +import com.google.common.base.Preconditions; +import com.google.common.collect.Lists; +import org.apache.hadoop.fs.FSDataInputStream; +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.Path; +import org.apache.mahout.classifier.df.data.Dataset.Attribute; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.Scanner; +import java.util.Set; +import java.util.regex.Pattern; + +/** + * Converts the input data to a Vector Array using the information given by the Dataset.<br> + * Generates for each line a Vector that contains :<br> + * <ul> + * <li>double parsed value for NUMERICAL attributes</li> + * <li>int value for CATEGORICAL and LABEL attributes</li> + * </ul> + * <br> + * adds an IGNORED first attribute that will contain a unique id for each instance, which is the line number + * of the instance in the input data + */ +@Deprecated +public final class DataLoader { + + private static final Logger log = LoggerFactory.getLogger(DataLoader.class); + + private 
static final Pattern SEPARATORS = Pattern.compile("[, ]"); + + private DataLoader() {} + + /** + * Converts a comma-separated String to a Vector. + * + * @param attrs + * attributes description + * @param values + * used to convert CATEGORICAL attribute values to Integer + * @return false if there are missing values '?' or NUMERICAL attribute values is not numeric + */ + private static boolean parseString(Attribute[] attrs, Set<String>[] values, CharSequence string, + boolean regression) { + String[] tokens = SEPARATORS.split(string); + Preconditions.checkArgument(tokens.length == attrs.length, + "Wrong number of attributes in the string: " + tokens.length + ". Must be: " + attrs.length); + + // extract tokens and check is there is any missing value + for (int attr = 0; attr < attrs.length; attr++) { + if (!attrs[attr].isIgnored() && "?".equals(tokens[attr])) { + return false; // missing value + } + } + + for (int attr = 0; attr < attrs.length; attr++) { + if (!attrs[attr].isIgnored()) { + String token = tokens[attr]; + if (attrs[attr].isCategorical() || (!regression && attrs[attr].isLabel())) { + // update values + if (values[attr] == null) { + values[attr] = new HashSet<>(); + } + values[attr].add(token); + } else { + try { + Double.parseDouble(token); + } catch (NumberFormatException e) { + return false; + } + } + } + } + + return true; + } + + /** + * Loads the data from a file + * + * @param fs + * file system + * @param fpath + * data file path + * @throws IOException + * if any problem is encountered + */ + + public static Data loadData(Dataset dataset, FileSystem fs, Path fpath) throws IOException { + FSDataInputStream input = fs.open(fpath); + Scanner scanner = new Scanner(input, "UTF-8"); + + List<Instance> instances = new ArrayList<>(); + + DataConverter converter = new DataConverter(dataset); + + while (scanner.hasNextLine()) { + String line = scanner.nextLine(); + if (!line.isEmpty()) { + Instance instance = converter.convert(line); + if (instance != 
null) { + instances.add(instance); + } else { + // missing values found + log.warn("{}: missing values", instances.size()); + } + } else { + log.warn("{}: empty string", instances.size()); + } + } + + scanner.close(); + return new Data(dataset, instances); + } + + + /** Loads the data from multiple paths specified by pathes */ + public static Data loadData(Dataset dataset, FileSystem fs, Path[] pathes) throws IOException { + List<Instance> instances = new ArrayList<>(); + + for (Path path : pathes) { + Data loadedData = loadData(dataset, fs, path); + for (int index = 0; index <= loadedData.size(); index++) { + instances.add(loadedData.get(index)); + } + } + return new Data(dataset, instances); + } + + /** Loads the data from a String array */ + public static Data loadData(Dataset dataset, String[] data) { + List<Instance> instances = new ArrayList<>(); + + DataConverter converter = new DataConverter(dataset); + + for (String line : data) { + if (!line.isEmpty()) { + Instance instance = converter.convert(line); + if (instance != null) { + instances.add(instance); + } else { + // missing values found + log.warn("{}: missing values", instances.size()); + } + } else { + log.warn("{}: empty string", instances.size()); + } + } + + return new Data(dataset, instances); + } + + /** + * Generates the Dataset by parsing the entire data + * + * @param descriptor attributes description + * @param regression if true, the label is numerical + * @param fs file system + * @param path data path + */ + public static Dataset generateDataset(CharSequence descriptor, + boolean regression, + FileSystem fs, + Path path) throws DescriptorException, IOException { + Attribute[] attrs = DescriptorUtils.parseDescriptor(descriptor); + + FSDataInputStream input = fs.open(path); + Scanner scanner = new Scanner(input, "UTF-8"); + + // used to convert CATEGORICAL attribute to Integer + @SuppressWarnings("unchecked") + Set<String>[] valsets = new Set[attrs.length]; + + int size = 0; + while 
(scanner.hasNextLine()) { + String line = scanner.nextLine(); + if (!line.isEmpty()) { + if (parseString(attrs, valsets, line, regression)) { + size++; + } + } + } + + scanner.close(); + + @SuppressWarnings("unchecked") + List<String>[] values = new List[attrs.length]; + for (int i = 0; i < valsets.length; i++) { + if (valsets[i] != null) { + values[i] = Lists.newArrayList(valsets[i]); + } + } + + return new Dataset(attrs, values, size, regression); + } + + /** + * Generates the Dataset by parsing the entire data + * + * @param descriptor + * attributes description + */ + public static Dataset generateDataset(CharSequence descriptor, + boolean regression, + String[] data) throws DescriptorException { + Attribute[] attrs = DescriptorUtils.parseDescriptor(descriptor); + + // used to convert CATEGORICAL attributes to Integer + @SuppressWarnings("unchecked") + Set<String>[] valsets = new Set[attrs.length]; + + int size = 0; + for (String aData : data) { + if (!aData.isEmpty()) { + if (parseString(attrs, valsets, aData, regression)) { + size++; + } + } + } + + @SuppressWarnings("unchecked") + List<String>[] values = new List[attrs.length]; + for (int i = 0; i < valsets.length; i++) { + if (valsets[i] != null) { + values[i] = Lists.newArrayList(valsets[i]); + } + } + + return new Dataset(attrs, values, size, regression); + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/data/DataUtils.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/data/DataUtils.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/data/DataUtils.java new file mode 100644 index 0000000..0889370 --- /dev/null +++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/data/DataUtils.java @@ -0,0 +1,89 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or 
more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.classifier.df.data; + +import com.google.common.base.Preconditions; + +import java.util.ArrayList; +import java.util.List; +import java.util.Random; + +/** + * Helper methods that deals with data lists and arrays of values + */ +@Deprecated +public final class DataUtils { + private DataUtils() { } + + /** + * Computes the sum of the values + * + */ + public static int sum(int[] values) { + int sum = 0; + for (int value : values) { + sum += value; + } + + return sum; + } + + /** + * foreach i : array1[i] += array2[i] + */ + public static void add(int[] array1, int[] array2) { + Preconditions.checkArgument(array1.length == array2.length, "array1.length != array2.length"); + for (int index = 0; index < array1.length; index++) { + array1[index] += array2[index]; + } + } + + /** + * foreach i : array1[i] -= array2[i] + */ + public static void dec(int[] array1, int[] array2) { + Preconditions.checkArgument(array1.length == array2.length, "array1.length != array2.length"); + for (int index = 0; index < array1.length; index++) { + array1[index] -= array2[index]; + } + } + + /** + * return the index of the maximum of the array, breaking ties randomly + * + * @param rng + * used to break ties + * @return index of the 
maximum + */ + public static int maxindex(Random rng, int[] values) { + int max = 0; + List<Integer> maxindices = new ArrayList<>(); + + for (int index = 0; index < values.length; index++) { + if (values[index] > max) { + max = values[index]; + maxindices.clear(); + maxindices.add(index); + } else if (values[index] == max) { + maxindices.add(index); + } + } + + return maxindices.size() > 1 ? maxindices.get(rng.nextInt(maxindices.size())) : maxindices.get(0); + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/data/Dataset.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/data/Dataset.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/data/Dataset.java new file mode 100644 index 0000000..a392669 --- /dev/null +++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/data/Dataset.java @@ -0,0 +1,422 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.classifier.df.data; + +import com.google.common.base.Preconditions; +import com.google.common.io.Closeables; +import org.apache.commons.lang3.ArrayUtils; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.FSDataInputStream; +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.Path; +import org.codehaus.jackson.map.ObjectMapper; +import org.codehaus.jackson.type.TypeReference; + +import java.io.IOException; +import java.nio.charset.Charset; +import java.util.Arrays; +import java.util.HashMap; +import java.util.LinkedList; +import java.util.List; +import java.util.Locale; +import java.util.Map; + +/** + * Contains information about the attributes. + */ +@Deprecated +public class Dataset { + + /** + * Attributes type + */ + public enum Attribute { + IGNORED, + NUMERICAL, + CATEGORICAL, + LABEL; + + public boolean isNumerical() { + return this == NUMERICAL; + } + + public boolean isCategorical() { + return this == CATEGORICAL; + } + + public boolean isLabel() { + return this == LABEL; + } + + public boolean isIgnored() { + return this == IGNORED; + } + + private static Attribute fromString(String from) { + Attribute toReturn = LABEL; + if (NUMERICAL.toString().equalsIgnoreCase(from)) { + toReturn = NUMERICAL; + } else if (CATEGORICAL.toString().equalsIgnoreCase(from)) { + toReturn = CATEGORICAL; + } else if (IGNORED.toString().equalsIgnoreCase(from)) { + toReturn = IGNORED; + } + return toReturn; + } + } + + private Attribute[] attributes; + + /** + * list of ignored attributes + */ + private int[] ignored; + + /** + * distinct values (CATEGORIAL attributes only) + */ + private String[][] values; + + /** + * index of the label attribute in the loaded data (without ignored attributed) + */ + private int labelId; + + /** + * number of instances in the dataset + */ + private int nbInstances; + + /** JSON serial/de-serial-izer */ + private static final ObjectMapper OBJECT_MAPPER = new 
ObjectMapper(); + + // Some literals for JSON representation + static final String TYPE = "type"; + static final String VALUES = "values"; + static final String LABEL = "label"; + + protected Dataset() {} + + /** + * Should only be called by a DataLoader + * + * @param attrs attributes description + * @param values distinct values for all CATEGORICAL attributes + */ + Dataset(Attribute[] attrs, List<String>[] values, int nbInstances, boolean regression) { + validateValues(attrs, values); + + int nbattrs = countAttributes(attrs); + + // the label values are set apart + attributes = new Attribute[nbattrs]; + this.values = new String[nbattrs][]; + ignored = new int[attrs.length - nbattrs]; // nbignored = total - nbattrs + + labelId = -1; + int ignoredId = 0; + int ind = 0; + for (int attr = 0; attr < attrs.length; attr++) { + if (attrs[attr].isIgnored()) { + ignored[ignoredId++] = attr; + continue; + } + + if (attrs[attr].isLabel()) { + if (labelId != -1) { + throw new IllegalStateException("Label found more than once"); + } + labelId = ind; + if (regression) { + attrs[attr] = Attribute.NUMERICAL; + } else { + attrs[attr] = Attribute.CATEGORICAL; + } + } + + if (attrs[attr].isCategorical() || (!regression && attrs[attr].isLabel())) { + this.values[ind] = new String[values[attr].size()]; + values[attr].toArray(this.values[ind]); + } + + attributes[ind++] = attrs[attr]; + } + + if (labelId == -1) { + throw new IllegalStateException("Label not found"); + } + + this.nbInstances = nbInstances; + } + + public int nbValues(int attr) { + return values[attr].length; + } + + public String[] labels() { + return Arrays.copyOf(values[labelId], nblabels()); + } + + public int nblabels() { + return values[labelId].length; + } + + public int getLabelId() { + return labelId; + } + + public double getLabel(Instance instance) { + return instance.get(getLabelId()); + } + + public Attribute getAttribute(int attr) { + return attributes[attr]; + } + + /** + * Returns the code used to 
represent the label value in the data + * + * @param label label's value to code + * @return label's code + */ + public int labelCode(String label) { + return ArrayUtils.indexOf(values[labelId], label); + } + + /** + * Returns the label value in the data + * This method can be used when the criterion variable is the categorical attribute. + * + * @param code label's code + * @return label's value + */ + public String getLabelString(double code) { + // handle the case (prediction is NaN) + if (Double.isNaN(code)) { + return "unknown"; + } + return values[labelId][(int) code]; + } + + @Override + public String toString() { + return "attributes=" + Arrays.toString(attributes); + } + + /** + * Converts a token to its corresponding integer code for a given attribute + * + * @param attr attribute index + */ + public int valueOf(int attr, String token) { + Preconditions.checkArgument(!isNumerical(attr), "Only for CATEGORICAL attributes"); + Preconditions.checkArgument(values != null, "Values not found (equals null)"); + return ArrayUtils.indexOf(values[attr], token); + } + + public int[] getIgnored() { + return ignored; + } + + /** + * @return number of attributes that are not IGNORED + */ + private static int countAttributes(Attribute[] attrs) { + int nbattrs = 0; + for (Attribute attr : attrs) { + if (!attr.isIgnored()) { + nbattrs++; + } + } + return nbattrs; + } + + private static void validateValues(Attribute[] attrs, List<String>[] values) { + Preconditions.checkArgument(attrs.length == values.length, "attrs.length != values.length"); + for (int attr = 0; attr < attrs.length; attr++) { + Preconditions.checkArgument(!attrs[attr].isCategorical() || values[attr] != null, + "values not found for attribute " + attr); + } + } + + /** + * @return number of attributes + */ + public int nbAttributes() { + return attributes.length; + } + + /** + * Is this a numerical attribute ? 
+ * + * @param attr index of the attribute to check + * @return true if the attribute is numerical + */ + public boolean isNumerical(int attr) { + return attributes[attr].isNumerical(); + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (!(obj instanceof Dataset)) { + return false; + } + + Dataset dataset = (Dataset) obj; + + if (!Arrays.equals(attributes, dataset.attributes)) { + return false; + } + + for (int attr = 0; attr < nbAttributes(); attr++) { + if (!Arrays.equals(values[attr], dataset.values[attr])) { + return false; + } + } + + return labelId == dataset.labelId && nbInstances == dataset.nbInstances; + } + + @Override + public int hashCode() { + int hashCode = labelId + 31 * nbInstances; + for (Attribute attr : attributes) { + hashCode = 31 * hashCode + attr.hashCode(); + } + for (String[] valueRow : values) { + if (valueRow == null) { + continue; + } + for (String value : valueRow) { + hashCode = 31 * hashCode + value.hashCode(); + } + } + return hashCode; + } + + /** + * Loads the dataset from a file + * + * @throws java.io.IOException + */ + public static Dataset load(Configuration conf, Path path) throws IOException { + FileSystem fs = path.getFileSystem(conf); + long bytesToRead = fs.getFileStatus(path).getLen(); + byte[] buff = new byte[Long.valueOf(bytesToRead).intValue()]; + FSDataInputStream input = fs.open(path); + try { + input.readFully(buff); + } finally { + Closeables.close(input, true); + } + String json = new String(buff, Charset.defaultCharset()); + return fromJSON(json); + } + + + /** + * Serialize this instance to JSON + * @return some JSON + */ + public String toJSON() { + List<Map<String, Object>> toWrite = new LinkedList<>(); + // attributes does not include ignored columns and it does include the class label + int ignoredCount = 0; + for (int i = 0; i < attributes.length + ignored.length; i++) { + Map<String, Object> attribute; + int attributesIndex = i - ignoredCount; + if 
(ignoredCount < ignored.length && i == ignored[ignoredCount]) { + // fill in ignored atttribute + attribute = getMap(Attribute.IGNORED, null, false); + ignoredCount++; + } else if (attributesIndex == labelId) { + // fill in the label + attribute = getMap(attributes[attributesIndex], values[attributesIndex], true); + } else { + // normal attribute + attribute = getMap(attributes[attributesIndex], values[attributesIndex], false); + } + toWrite.add(attribute); + } + try { + return OBJECT_MAPPER.writeValueAsString(toWrite); + } catch (Exception ex) { + throw new RuntimeException(ex); + } + } + + /** + * De-serialize an instance from a string + * @param json From which an instance is created + * @return A shiny new Dataset + */ + public static Dataset fromJSON(String json) { + List<Map<String, Object>> fromJSON; + try { + fromJSON = OBJECT_MAPPER.readValue(json, new TypeReference<List<Map<String, Object>>>() {}); + } catch (Exception ex) { + throw new RuntimeException(ex); + } + List<Attribute> attributes = new LinkedList<>(); + List<Integer> ignored = new LinkedList<>(); + String[][] nominalValues = new String[fromJSON.size()][]; + Dataset dataset = new Dataset(); + for (int i = 0; i < fromJSON.size(); i++) { + Map<String, Object> attribute = fromJSON.get(i); + if (Attribute.fromString((String) attribute.get(TYPE)) == Attribute.IGNORED) { + ignored.add(i); + } else { + Attribute asAttribute = Attribute.fromString((String) attribute.get(TYPE)); + attributes.add(asAttribute); + if ((Boolean) attribute.get(LABEL)) { + dataset.labelId = i - ignored.size(); + } + if (attribute.get(VALUES) != null) { + List<String> get = (List<String>) attribute.get(VALUES); + String[] array = get.toArray(new String[get.size()]); + nominalValues[i - ignored.size()] = array; + } + } + } + dataset.attributes = attributes.toArray(new Attribute[attributes.size()]); + dataset.ignored = new int[ignored.size()]; + dataset.values = nominalValues; + for (int i = 0; i < dataset.ignored.length; i++) { 
+ dataset.ignored[i] = ignored.get(i); + } + return dataset; + } + + /** + * Generate a map to describe an attribute + * @param type The type + * @param values - values + * @param isLabel - is a label + * @return map of (AttributeTypes, Values) + */ + private Map<String, Object> getMap(Attribute type, String[] values, boolean isLabel) { + Map<String, Object> attribute = new HashMap<>(); + attribute.put(TYPE, type.toString().toLowerCase(Locale.getDefault())); + attribute.put(VALUES, values); + attribute.put(LABEL, isLabel); + return attribute; + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/data/DescriptorException.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/data/DescriptorException.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/data/DescriptorException.java new file mode 100644 index 0000000..e7a10ff --- /dev/null +++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/data/DescriptorException.java @@ -0,0 +1,28 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
/**
 * Checked exception signalling that a descriptor string could not be parsed.
 */
@Deprecated
public class DescriptorException extends Exception {

  /**
   * @param msg description of the problem, available through {@link #getMessage()}
   */
  public DescriptorException(String msg) {
    super(msg);
  }
}
+ */ + +package org.apache.mahout.classifier.df.data; + +import com.google.common.base.Splitter; +import org.apache.mahout.classifier.df.data.Dataset.Attribute; + +import java.util.ArrayList; +import java.util.List; +import java.util.Locale; + +/** + * Contains various methods that deal with descriptor strings + */ +@Deprecated +public final class DescriptorUtils { + + private static final Splitter SPACE = Splitter.on(' ').omitEmptyStrings(); + + private DescriptorUtils() { } + + /** + * Parses a descriptor string and generates the corresponding array of Attributes + * + * @throws DescriptorException + * if a bad token is encountered + */ + public static Attribute[] parseDescriptor(CharSequence descriptor) throws DescriptorException { + List<Attribute> attributes = new ArrayList<>(); + for (String token : SPACE.split(descriptor)) { + token = token.toUpperCase(Locale.ENGLISH); + if ("I".equals(token)) { + attributes.add(Attribute.IGNORED); + } else if ("N".equals(token)) { + attributes.add(Attribute.NUMERICAL); + } else if ("C".equals(token)) { + attributes.add(Attribute.CATEGORICAL); + } else if ("L".equals(token)) { + attributes.add(Attribute.LABEL); + } else { + throw new DescriptorException("Bad Token : " + token); + } + } + return attributes.toArray(new Attribute[attributes.size()]); + } + + /** + * Generates a valid descriptor string from a user-friendly representation.<br> + * for example "3 N I N N 2 C L 5 I" generates "N N N I N N C C L I I I I I".<br> + * this useful when describing datasets with a large number of attributes + * @throws DescriptorException + */ + public static String generateDescriptor(CharSequence description) throws DescriptorException { + return generateDescriptor(SPACE.split(description)); + } + + /** + * Generates a valid descriptor string from a list of tokens + * @throws DescriptorException + */ + public static String generateDescriptor(Iterable<String> tokens) throws DescriptorException { + StringBuilder descriptor = new 
StringBuilder(); + + int multiplicator = 0; + + for (String token : tokens) { + try { + // try to parse an integer + int number = Integer.parseInt(token); + + if (number <= 0) { + throw new DescriptorException("Multiplicator (" + number + ") must be > 0"); + } + if (multiplicator > 0) { + throw new DescriptorException("A multiplicator cannot be followed by another multiplicator"); + } + + multiplicator = number; + } catch (NumberFormatException e) { + // token is not a number + if (multiplicator == 0) { + multiplicator = 1; + } + + for (int index = 0; index < multiplicator; index++) { + descriptor.append(token).append(' '); + } + + multiplicator = 0; + } + } + + return descriptor.toString().trim(); + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/data/Instance.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/data/Instance.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/data/Instance.java new file mode 100644 index 0000000..6a23cb8 --- /dev/null +++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/data/Instance.java @@ -0,0 +1,75 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.classifier.df.data; + +import org.apache.mahout.math.Vector; + +/** + * Represents one data instance. + */ +@Deprecated +public class Instance { + + /** attributes, except LABEL and IGNORED */ + private final Vector attrs; + + public Instance(Vector attrs) { + this.attrs = attrs; + } + + /** + * Return the attribute at the specified position + * + * @param index + * position of the attribute to retrieve + * @return value of the attribute + */ + public double get(int index) { + return attrs.getQuick(index); + } + + /** + * Set the value at the given index + * + * @param value + * a double value to set + */ + public void set(int index, double value) { + attrs.set(index, value); + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (!(obj instanceof Instance)) { + return false; + } + + Instance instance = (Instance) obj; + + return /*id == instance.id &&*/ attrs.equals(instance.attrs); + + } + + @Override + public int hashCode() { + return /*id +*/ attrs.hashCode(); + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/data/conditions/Condition.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/data/conditions/Condition.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/data/conditions/Condition.java new file mode 100644 index 0000000..c16ca3f --- /dev/null +++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/data/conditions/Condition.java @@ -0,0 +1,57 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.classifier.df.data.conditions; + +import org.apache.mahout.classifier.df.data.Instance; + +/** + * Condition on Instance + */ +@Deprecated +public abstract class Condition { + + /** + * Returns true is the checked instance matches the condition + * + * @param instance + * checked instance + * @return true is the checked instance matches the condition + */ + public abstract boolean isTrueFor(Instance instance); + + /** + * Condition that checks if the given attribute has a value "equal" to the given value + */ + public static Condition equals(int attr, double value) { + return new Equals(attr, value); + } + + /** + * Condition that checks if the given attribute has a value "lesser" than the given value + */ + public static Condition lesser(int attr, double value) { + return new Lesser(attr, value); + } + + /** + * Condition that checks if the given attribute has a value "greater or equal" than the given value + */ + public static Condition greaterOrEquals(int attr, double value) { + return new GreaterOrEquals(attr, value); + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/data/conditions/Equals.java 
---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/data/conditions/Equals.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/data/conditions/Equals.java new file mode 100644 index 0000000..c51082b --- /dev/null +++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/data/conditions/Equals.java @@ -0,0 +1,42 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.classifier.df.data.conditions; + +import org.apache.mahout.classifier.df.data.Instance; + +/** + * True if a given attribute has a given value. + * The attribute {@code attr} and the reference {@code value} are fixed by the constructor and stored in final fields; + * {@link #isTrueFor(Instance)} returns {@code instance.get(attr) == value}, an exact comparison with no tolerance. + */ +@Deprecated +public class Equals extends Condition { + + private final int attr; + + private final double value; + + public Equals(int attr, double value) { + this.attr = attr; + this.value = value; + } + + @Override + public boolean isTrueFor(Instance instance) { + return instance.get(attr) == value; + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/data/conditions/GreaterOrEquals.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/data/conditions/GreaterOrEquals.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/data/conditions/GreaterOrEquals.java new file mode 100644 index 0000000..3e3d1a4 --- /dev/null +++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/data/conditions/GreaterOrEquals.java @@ -0,0 +1,42 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.classifier.df.data.conditions; + +import org.apache.mahout.classifier.df.data.Instance; + +/** + * True if a given attribute has a value "greater than or equal to" a given value. + * The attribute {@code attr} and the reference {@code value} are fixed by the constructor and stored in final fields; + * {@link #isTrueFor(Instance)} returns {@code v.get(attr) >= value}. + */ +@Deprecated +public class GreaterOrEquals extends Condition { + + private final int attr; + + private final double value; + + public GreaterOrEquals(int attr, double value) { + this.attr = attr; + this.value = value; + } + + @Override + public boolean isTrueFor(Instance v) { + return v.get(attr) >= value; + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/data/conditions/Lesser.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/data/conditions/Lesser.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/data/conditions/Lesser.java new file mode 100644 index 0000000..577cb24 --- /dev/null +++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/data/conditions/Lesser.java @@ -0,0 +1,42 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.classifier.df.data.conditions; + +import org.apache.mahout.classifier.df.data.Instance; + +/** + * True if a given attribute has a value strictly "lesser" than a given value. + * The attribute {@code attr} and the reference {@code value} are fixed by the constructor and stored in final fields; + * {@link #isTrueFor(Instance)} returns {@code instance.get(attr) < value}, so a value equal to the reference does not match. + */ +@Deprecated +public class Lesser extends Condition { + + private final int attr; + + private final double value; + + public Lesser(int attr, double value) { + this.attr = attr; + this.value = value; + } + + @Override + public boolean isTrueFor(Instance instance) { + return instance.get(attr) < value; + } + +}
