http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/builder/TreeBuilder.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/builder/TreeBuilder.java b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/builder/TreeBuilder.java new file mode 100644 index 0000000..bf686a4 --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/builder/TreeBuilder.java @@ -0,0 +1,42 @@
/**
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.mahout.classifier.df.builder;

import org.apache.mahout.classifier.df.data.Data;
import org.apache.mahout.classifier.df.node.Node;

import java.util.Random;

/**
 * Interface for builders of a single decision tree: an implementation takes training data and returns
 * the root of the tree it grows from that data.
 */
@Deprecated
public interface TreeBuilder {

  /**
   * Builds a decision tree using the training data.
   *
   * @param rng
   *          random-numbers generator made available to the implementation
   *          (NOTE(review): what it is used for is implementation-specific — confirm per implementation)
   * @param data
   *          training data
   * @return the root {@link Node} of the built tree
   */
  Node build(Random rng, Data data);

}
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/data/Data.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/data/Data.java b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/data/Data.java new file mode 100644 index 0000000..77e5ed5 --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/data/Data.java @@ -0,0 +1,281 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.classifier.df.data; + +import org.apache.mahout.classifier.df.data.conditions.Condition; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.HashSet; +import java.util.List; +import java.util.Random; + +/** + * Holds a list of vectors and their corresponding Dataset. contains various operations that deals with the + * vectors (subset, count,...) 
+ * + */ +@Deprecated +public class Data implements Cloneable { + + private final List<Instance> instances; + + private final Dataset dataset; + + public Data(Dataset dataset) { + this.dataset = dataset; + this.instances = new ArrayList<>(); + } + + public Data(Dataset dataset, List<Instance> instances) { + this.dataset = dataset; + this.instances = new ArrayList<>(instances); + } + + /** + * @return the number of elements + */ + public int size() { + return instances.size(); + } + + /** + * @return true if this data contains no element + */ + public boolean isEmpty() { + return instances.isEmpty(); + } + + /** + * @param v + * element whose presence in this list if to be searched + * @return true is this data contains the specified element. + */ + public boolean contains(Instance v) { + return instances.contains(v); + } + + /** + * Returns the element at the specified position + * + * @param index + * index of element to return + * @return the element at the specified position + * @throws IndexOutOfBoundsException + * if the index is out of range + */ + public Instance get(int index) { + return instances.get(index); + } + + /** + * @return the subset from this data that matches the given condition + */ + public Data subset(Condition condition) { + List<Instance> subset = new ArrayList<>(); + + for (Instance instance : instances) { + if (condition.isTrueFor(instance)) { + subset.add(instance); + } + } + + return new Data(dataset, subset); + } + + /** + * if data has N cases, sample N cases at random -but with replacement. + */ + public Data bagging(Random rng) { + int datasize = size(); + List<Instance> bag = new ArrayList<>(datasize); + + for (int i = 0; i < datasize; i++) { + bag.add(instances.get(rng.nextInt(datasize))); + } + + return new Data(dataset, bag); + } + + /** + * if data has N cases, sample N cases at random -but with replacement. 
+ * + * @param sampled + * indicating which instance has been sampled + * + * @return sampled data + */ + public Data bagging(Random rng, boolean[] sampled) { + int datasize = size(); + List<Instance> bag = new ArrayList<>(datasize); + + for (int i = 0; i < datasize; i++) { + int index = rng.nextInt(datasize); + bag.add(instances.get(index)); + sampled[index] = true; + } + + return new Data(dataset, bag); + } + + /** + * Splits the data in two, returns one part, and this gets the rest of the data. <b>VERY SLOW!</b> + */ + public Data rsplit(Random rng, int subsize) { + List<Instance> subset = new ArrayList<>(subsize); + + for (int i = 0; i < subsize; i++) { + subset.add(instances.remove(rng.nextInt(instances.size()))); + } + + return new Data(dataset, subset); + } + + /** + * checks if all the vectors have identical attribute values + * + * @return true is all the vectors are identical or the data is empty<br> + * false otherwise + */ + public boolean isIdentical() { + if (isEmpty()) { + return true; + } + + Instance instance = get(0); + for (int attr = 0; attr < dataset.nbAttributes(); attr++) { + for (int index = 1; index < size(); index++) { + if (get(index).get(attr) != instance.get(attr)) { + return false; + } + } + } + + return true; + } + + /** + * checks if all the vectors have identical label values + */ + public boolean identicalLabel() { + if (isEmpty()) { + return true; + } + + double label = dataset.getLabel(get(0)); + for (int index = 1; index < size(); index++) { + if (dataset.getLabel(get(index)) != label) { + return false; + } + } + + return true; + } + + /** + * finds all distinct values of a given attribute + */ + public double[] values(int attr) { + Collection<Double> result = new HashSet<>(); + + for (Instance instance : instances) { + result.add(instance.get(attr)); + } + + double[] values = new double[result.size()]; + + int index = 0; + for (Double value : result) { + values[index++] = value; + } + + return values; + } + + @Override + public 
Data clone() { + return new Data(dataset, new ArrayList<>(instances)); + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (!(obj instanceof Data)) { + return false; + } + + Data data = (Data) obj; + + return instances.equals(data.instances) && dataset.equals(data.dataset); + } + + @Override + public int hashCode() { + return instances.hashCode() + dataset.hashCode(); + } + + /** + * extract the labels of all instances + */ + public double[] extractLabels() { + double[] labels = new double[size()]; + + for (int index = 0; index < labels.length; index++) { + labels[index] = dataset.getLabel(get(index)); + } + + return labels; + } + + /** + * finds the majority label, breaking ties randomly<br> + * This method can be used when the criterion variable is the categorical attribute. + * + * @return the majority label value + */ + public int majorityLabel(Random rng) { + // count the frequency of each label value + int[] counts = new int[dataset.nblabels()]; + + for (int index = 0; index < size(); index++) { + counts[(int) dataset.getLabel(get(index))]++; + } + + // find the label values that appears the most + return DataUtils.maxindex(rng, counts); + } + + /** + * Counts the number of occurrences of each label value<br> + * This method can be used when the criterion variable is the categorical attribute. 
+ * + * @param counts + * will contain the results, supposed to be initialized at 0 + */ + public void countLabels(int[] counts) { + for (int index = 0; index < size(); index++) { + counts[(int) dataset.getLabel(get(index))]++; + } + } + + public Dataset getDataset() { + return dataset; + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/data/DataConverter.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/data/DataConverter.java b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/data/DataConverter.java new file mode 100644 index 0000000..f1bdc95 --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/data/DataConverter.java @@ -0,0 +1,72 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.classifier.df.data; + +import com.google.common.base.Preconditions; +import org.apache.commons.lang3.ArrayUtils; +import org.apache.mahout.math.DenseVector; + +import java.util.regex.Pattern; + +/** + * Converts String to Instance using a Dataset + */ +@Deprecated +public class DataConverter { + + private static final Pattern COMMA_SPACE = Pattern.compile("[, ]"); + + private final Dataset dataset; + + public DataConverter(Dataset dataset) { + this.dataset = dataset; + } + + public Instance convert(CharSequence string) { + // all attributes (categorical, numerical, label), ignored + int nball = dataset.nbAttributes() + dataset.getIgnored().length; + + String[] tokens = COMMA_SPACE.split(string); + Preconditions.checkArgument(tokens.length == nball, + "Wrong number of attributes in the string: " + tokens.length + ". Must be " + nball); + + int nbattrs = dataset.nbAttributes(); + DenseVector vector = new DenseVector(nbattrs); + + int aId = 0; + for (int attr = 0; attr < nball; attr++) { + if (!ArrayUtils.contains(dataset.getIgnored(), attr)) { + String token = tokens[attr].trim(); + + if ("?".equals(token)) { + // missing value + return null; + } + + if (dataset.isNumerical(aId)) { + vector.set(aId++, Double.parseDouble(token)); + } else { // CATEGORICAL + vector.set(aId, dataset.valueOf(aId, token)); + aId++; + } + } + } + + return new Instance(vector); + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/data/DataLoader.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/data/DataLoader.java b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/data/DataLoader.java new file mode 100644 index 0000000..c62dcac --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/data/DataLoader.java @@ -0,0 +1,255 @@ +/** + 
* Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.classifier.df.data; + +import com.google.common.base.Preconditions; +import com.google.common.collect.Lists; +import org.apache.hadoop.fs.FSDataInputStream; +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.Path; +import org.apache.mahout.classifier.df.data.Dataset.Attribute; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.Scanner; +import java.util.Set; +import java.util.regex.Pattern; + +/** + * Converts the input data to a Vector Array using the information given by the Dataset.<br> + * Generates for each line a Vector that contains :<br> + * <ul> + * <li>double parsed value for NUMERICAL attributes</li> + * <li>int value for CATEGORICAL and LABEL attributes</li> + * </ul> + * <br> + * adds an IGNORED first attribute that will contain a unique id for each instance, which is the line number + * of the instance in the input data + */ +@Deprecated +public final class DataLoader { + + private static final Logger log = LoggerFactory.getLogger(DataLoader.class); + + private static final 
Pattern SEPARATORS = Pattern.compile("[, ]"); + + private DataLoader() {} + + /** + * Converts a comma-separated String to a Vector. + * + * @param attrs + * attributes description + * @param values + * used to convert CATEGORICAL attribute values to Integer + * @return false if there are missing values '?' or NUMERICAL attribute values is not numeric + */ + private static boolean parseString(Attribute[] attrs, Set<String>[] values, CharSequence string, + boolean regression) { + String[] tokens = SEPARATORS.split(string); + Preconditions.checkArgument(tokens.length == attrs.length, + "Wrong number of attributes in the string: " + tokens.length + ". Must be: " + attrs.length); + + // extract tokens and check is there is any missing value + for (int attr = 0; attr < attrs.length; attr++) { + if (!attrs[attr].isIgnored() && "?".equals(tokens[attr])) { + return false; // missing value + } + } + + for (int attr = 0; attr < attrs.length; attr++) { + if (!attrs[attr].isIgnored()) { + String token = tokens[attr]; + if (attrs[attr].isCategorical() || (!regression && attrs[attr].isLabel())) { + // update values + if (values[attr] == null) { + values[attr] = new HashSet<>(); + } + values[attr].add(token); + } else { + try { + Double.parseDouble(token); + } catch (NumberFormatException e) { + return false; + } + } + } + } + + return true; + } + + /** + * Loads the data from a file + * + * @param fs + * file system + * @param fpath + * data file path + * @throws IOException + * if any problem is encountered + */ + + public static Data loadData(Dataset dataset, FileSystem fs, Path fpath) throws IOException { + FSDataInputStream input = fs.open(fpath); + Scanner scanner = new Scanner(input, "UTF-8"); + + List<Instance> instances = new ArrayList<>(); + + DataConverter converter = new DataConverter(dataset); + + while (scanner.hasNextLine()) { + String line = scanner.nextLine(); + if (!line.isEmpty()) { + Instance instance = converter.convert(line); + if (instance != null) { + 
instances.add(instance); + } else { + // missing values found + log.warn("{}: missing values", instances.size()); + } + } else { + log.warn("{}: empty string", instances.size()); + } + } + + scanner.close(); + return new Data(dataset, instances); + } + + + /** Loads the data from multiple paths specified by pathes */ + public static Data loadData(Dataset dataset, FileSystem fs, Path[] pathes) throws IOException { + List<Instance> instances = new ArrayList<>(); + + for (Path path : pathes) { + Data loadedData = loadData(dataset, fs, path); + for (int index = 0; index <= loadedData.size(); index++) { + instances.add(loadedData.get(index)); + } + } + return new Data(dataset, instances); + } + + /** Loads the data from a String array */ + public static Data loadData(Dataset dataset, String[] data) { + List<Instance> instances = new ArrayList<>(); + + DataConverter converter = new DataConverter(dataset); + + for (String line : data) { + if (!line.isEmpty()) { + Instance instance = converter.convert(line); + if (instance != null) { + instances.add(instance); + } else { + // missing values found + log.warn("{}: missing values", instances.size()); + } + } else { + log.warn("{}: empty string", instances.size()); + } + } + + return new Data(dataset, instances); + } + + /** + * Generates the Dataset by parsing the entire data + * + * @param descriptor attributes description + * @param regression if true, the label is numerical + * @param fs file system + * @param path data path + */ + public static Dataset generateDataset(CharSequence descriptor, + boolean regression, + FileSystem fs, + Path path) throws DescriptorException, IOException { + Attribute[] attrs = DescriptorUtils.parseDescriptor(descriptor); + + FSDataInputStream input = fs.open(path); + Scanner scanner = new Scanner(input, "UTF-8"); + + // used to convert CATEGORICAL attribute to Integer + @SuppressWarnings("unchecked") + Set<String>[] valsets = new Set[attrs.length]; + + int size = 0; + while 
(scanner.hasNextLine()) { + String line = scanner.nextLine(); + if (!line.isEmpty()) { + if (parseString(attrs, valsets, line, regression)) { + size++; + } + } + } + + scanner.close(); + + @SuppressWarnings("unchecked") + List<String>[] values = new List[attrs.length]; + for (int i = 0; i < valsets.length; i++) { + if (valsets[i] != null) { + values[i] = Lists.newArrayList(valsets[i]); + } + } + + return new Dataset(attrs, values, size, regression); + } + + /** + * Generates the Dataset by parsing the entire data + * + * @param descriptor + * attributes description + */ + public static Dataset generateDataset(CharSequence descriptor, + boolean regression, + String[] data) throws DescriptorException { + Attribute[] attrs = DescriptorUtils.parseDescriptor(descriptor); + + // used to convert CATEGORICAL attributes to Integer + @SuppressWarnings("unchecked") + Set<String>[] valsets = new Set[attrs.length]; + + int size = 0; + for (String aData : data) { + if (!aData.isEmpty()) { + if (parseString(attrs, valsets, aData, regression)) { + size++; + } + } + } + + @SuppressWarnings("unchecked") + List<String>[] values = new List[attrs.length]; + for (int i = 0; i < valsets.length; i++) { + if (valsets[i] != null) { + values[i] = Lists.newArrayList(valsets[i]); + } + } + + return new Dataset(attrs, values, size, regression); + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/data/DataUtils.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/data/DataUtils.java b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/data/DataUtils.java new file mode 100644 index 0000000..0889370 --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/data/DataUtils.java @@ -0,0 +1,89 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * 
contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.classifier.df.data; + +import com.google.common.base.Preconditions; + +import java.util.ArrayList; +import java.util.List; +import java.util.Random; + +/** + * Helper methods that deals with data lists and arrays of values + */ +@Deprecated +public final class DataUtils { + private DataUtils() { } + + /** + * Computes the sum of the values + * + */ + public static int sum(int[] values) { + int sum = 0; + for (int value : values) { + sum += value; + } + + return sum; + } + + /** + * foreach i : array1[i] += array2[i] + */ + public static void add(int[] array1, int[] array2) { + Preconditions.checkArgument(array1.length == array2.length, "array1.length != array2.length"); + for (int index = 0; index < array1.length; index++) { + array1[index] += array2[index]; + } + } + + /** + * foreach i : array1[i] -= array2[i] + */ + public static void dec(int[] array1, int[] array2) { + Preconditions.checkArgument(array1.length == array2.length, "array1.length != array2.length"); + for (int index = 0; index < array1.length; index++) { + array1[index] -= array2[index]; + } + } + + /** + * return the index of the maximum of the array, breaking ties randomly + * + * @param rng + * used to break ties + * @return index of the maximum + 
*/ + public static int maxindex(Random rng, int[] values) { + int max = 0; + List<Integer> maxindices = new ArrayList<>(); + + for (int index = 0; index < values.length; index++) { + if (values[index] > max) { + max = values[index]; + maxindices.clear(); + maxindices.add(index); + } else if (values[index] == max) { + maxindices.add(index); + } + } + + return maxindices.size() > 1 ? maxindices.get(rng.nextInt(maxindices.size())) : maxindices.get(0); + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/data/Dataset.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/data/Dataset.java b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/data/Dataset.java new file mode 100644 index 0000000..a392669 --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/data/Dataset.java @@ -0,0 +1,422 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.classifier.df.data; + +import com.google.common.base.Preconditions; +import com.google.common.io.Closeables; +import org.apache.commons.lang3.ArrayUtils; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.FSDataInputStream; +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.Path; +import org.codehaus.jackson.map.ObjectMapper; +import org.codehaus.jackson.type.TypeReference; + +import java.io.IOException; +import java.nio.charset.Charset; +import java.util.Arrays; +import java.util.HashMap; +import java.util.LinkedList; +import java.util.List; +import java.util.Locale; +import java.util.Map; + +/** + * Contains information about the attributes. + */ +@Deprecated +public class Dataset { + + /** + * Attributes type + */ + public enum Attribute { + IGNORED, + NUMERICAL, + CATEGORICAL, + LABEL; + + public boolean isNumerical() { + return this == NUMERICAL; + } + + public boolean isCategorical() { + return this == CATEGORICAL; + } + + public boolean isLabel() { + return this == LABEL; + } + + public boolean isIgnored() { + return this == IGNORED; + } + + private static Attribute fromString(String from) { + Attribute toReturn = LABEL; + if (NUMERICAL.toString().equalsIgnoreCase(from)) { + toReturn = NUMERICAL; + } else if (CATEGORICAL.toString().equalsIgnoreCase(from)) { + toReturn = CATEGORICAL; + } else if (IGNORED.toString().equalsIgnoreCase(from)) { + toReturn = IGNORED; + } + return toReturn; + } + } + + private Attribute[] attributes; + + /** + * list of ignored attributes + */ + private int[] ignored; + + /** + * distinct values (CATEGORIAL attributes only) + */ + private String[][] values; + + /** + * index of the label attribute in the loaded data (without ignored attributed) + */ + private int labelId; + + /** + * number of instances in the dataset + */ + private int nbInstances; + + /** JSON serial/de-serial-izer */ + private static final ObjectMapper OBJECT_MAPPER = new 
ObjectMapper(); + + // Some literals for JSON representation + static final String TYPE = "type"; + static final String VALUES = "values"; + static final String LABEL = "label"; + + protected Dataset() {} + + /** + * Should only be called by a DataLoader + * + * @param attrs attributes description + * @param values distinct values for all CATEGORICAL attributes + */ + Dataset(Attribute[] attrs, List<String>[] values, int nbInstances, boolean regression) { + validateValues(attrs, values); + + int nbattrs = countAttributes(attrs); + + // the label values are set apart + attributes = new Attribute[nbattrs]; + this.values = new String[nbattrs][]; + ignored = new int[attrs.length - nbattrs]; // nbignored = total - nbattrs + + labelId = -1; + int ignoredId = 0; + int ind = 0; + for (int attr = 0; attr < attrs.length; attr++) { + if (attrs[attr].isIgnored()) { + ignored[ignoredId++] = attr; + continue; + } + + if (attrs[attr].isLabel()) { + if (labelId != -1) { + throw new IllegalStateException("Label found more than once"); + } + labelId = ind; + if (regression) { + attrs[attr] = Attribute.NUMERICAL; + } else { + attrs[attr] = Attribute.CATEGORICAL; + } + } + + if (attrs[attr].isCategorical() || (!regression && attrs[attr].isLabel())) { + this.values[ind] = new String[values[attr].size()]; + values[attr].toArray(this.values[ind]); + } + + attributes[ind++] = attrs[attr]; + } + + if (labelId == -1) { + throw new IllegalStateException("Label not found"); + } + + this.nbInstances = nbInstances; + } + + public int nbValues(int attr) { + return values[attr].length; + } + + public String[] labels() { + return Arrays.copyOf(values[labelId], nblabels()); + } + + public int nblabels() { + return values[labelId].length; + } + + public int getLabelId() { + return labelId; + } + + public double getLabel(Instance instance) { + return instance.get(getLabelId()); + } + + public Attribute getAttribute(int attr) { + return attributes[attr]; + } + + /** + * Returns the code used to 
represent the label value in the data + * + * @param label label's value to code + * @return label's code + */ + public int labelCode(String label) { + return ArrayUtils.indexOf(values[labelId], label); + } + + /** + * Returns the label value in the data + * This method can be used when the criterion variable is the categorical attribute. + * + * @param code label's code + * @return label's value + */ + public String getLabelString(double code) { + // handle the case (prediction is NaN) + if (Double.isNaN(code)) { + return "unknown"; + } + return values[labelId][(int) code]; + } + + @Override + public String toString() { + return "attributes=" + Arrays.toString(attributes); + } + + /** + * Converts a token to its corresponding integer code for a given attribute + * + * @param attr attribute index + */ + public int valueOf(int attr, String token) { + Preconditions.checkArgument(!isNumerical(attr), "Only for CATEGORICAL attributes"); + Preconditions.checkArgument(values != null, "Values not found (equals null)"); + return ArrayUtils.indexOf(values[attr], token); + } + + public int[] getIgnored() { + return ignored; + } + + /** + * @return number of attributes that are not IGNORED + */ + private static int countAttributes(Attribute[] attrs) { + int nbattrs = 0; + for (Attribute attr : attrs) { + if (!attr.isIgnored()) { + nbattrs++; + } + } + return nbattrs; + } + + private static void validateValues(Attribute[] attrs, List<String>[] values) { + Preconditions.checkArgument(attrs.length == values.length, "attrs.length != values.length"); + for (int attr = 0; attr < attrs.length; attr++) { + Preconditions.checkArgument(!attrs[attr].isCategorical() || values[attr] != null, + "values not found for attribute " + attr); + } + } + + /** + * @return number of attributes + */ + public int nbAttributes() { + return attributes.length; + } + + /** + * Is this a numerical attribute ? 
+ * + * @param attr index of the attribute to check + * @return true if the attribute is numerical + */ + public boolean isNumerical(int attr) { + return attributes[attr].isNumerical(); + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (!(obj instanceof Dataset)) { + return false; + } + + Dataset dataset = (Dataset) obj; + + if (!Arrays.equals(attributes, dataset.attributes)) { + return false; + } + + for (int attr = 0; attr < nbAttributes(); attr++) { + if (!Arrays.equals(values[attr], dataset.values[attr])) { + return false; + } + } + + return labelId == dataset.labelId && nbInstances == dataset.nbInstances; + } + + @Override + public int hashCode() { + int hashCode = labelId + 31 * nbInstances; + for (Attribute attr : attributes) { + hashCode = 31 * hashCode + attr.hashCode(); + } + for (String[] valueRow : values) { + if (valueRow == null) { + continue; + } + for (String value : valueRow) { + hashCode = 31 * hashCode + value.hashCode(); + } + } + return hashCode; + } + + /** + * Loads the dataset from a file + * + * @throws java.io.IOException + */ + public static Dataset load(Configuration conf, Path path) throws IOException { + FileSystem fs = path.getFileSystem(conf); + long bytesToRead = fs.getFileStatus(path).getLen(); + byte[] buff = new byte[Long.valueOf(bytesToRead).intValue()]; + FSDataInputStream input = fs.open(path); + try { + input.readFully(buff); + } finally { + Closeables.close(input, true); + } + String json = new String(buff, Charset.defaultCharset()); + return fromJSON(json); + } + + + /** + * Serialize this instance to JSON + * @return some JSON + */ + public String toJSON() { + List<Map<String, Object>> toWrite = new LinkedList<>(); + // attributes does not include ignored columns and it does include the class label + int ignoredCount = 0; + for (int i = 0; i < attributes.length + ignored.length; i++) { + Map<String, Object> attribute; + int attributesIndex = i - ignoredCount; + if 
(ignoredCount < ignored.length && i == ignored[ignoredCount]) { + // fill in ignored atttribute + attribute = getMap(Attribute.IGNORED, null, false); + ignoredCount++; + } else if (attributesIndex == labelId) { + // fill in the label + attribute = getMap(attributes[attributesIndex], values[attributesIndex], true); + } else { + // normal attribute + attribute = getMap(attributes[attributesIndex], values[attributesIndex], false); + } + toWrite.add(attribute); + } + try { + return OBJECT_MAPPER.writeValueAsString(toWrite); + } catch (Exception ex) { + throw new RuntimeException(ex); + } + } + + /** + * De-serialize an instance from a string + * @param json From which an instance is created + * @return A shiny new Dataset + */ + public static Dataset fromJSON(String json) { + List<Map<String, Object>> fromJSON; + try { + fromJSON = OBJECT_MAPPER.readValue(json, new TypeReference<List<Map<String, Object>>>() {}); + } catch (Exception ex) { + throw new RuntimeException(ex); + } + List<Attribute> attributes = new LinkedList<>(); + List<Integer> ignored = new LinkedList<>(); + String[][] nominalValues = new String[fromJSON.size()][]; + Dataset dataset = new Dataset(); + for (int i = 0; i < fromJSON.size(); i++) { + Map<String, Object> attribute = fromJSON.get(i); + if (Attribute.fromString((String) attribute.get(TYPE)) == Attribute.IGNORED) { + ignored.add(i); + } else { + Attribute asAttribute = Attribute.fromString((String) attribute.get(TYPE)); + attributes.add(asAttribute); + if ((Boolean) attribute.get(LABEL)) { + dataset.labelId = i - ignored.size(); + } + if (attribute.get(VALUES) != null) { + List<String> get = (List<String>) attribute.get(VALUES); + String[] array = get.toArray(new String[get.size()]); + nominalValues[i - ignored.size()] = array; + } + } + } + dataset.attributes = attributes.toArray(new Attribute[attributes.size()]); + dataset.ignored = new int[ignored.size()]; + dataset.values = nominalValues; + for (int i = 0; i < dataset.ignored.length; i++) { 
+ dataset.ignored[i] = ignored.get(i); + } + return dataset; + } + + /** + * Generate a map to describe an attribute + * @param type The type + * @param values - values + * @param isLabel - is a label + * @return map of (AttributeTypes, Values) + */ + private Map<String, Object> getMap(Attribute type, String[] values, boolean isLabel) { + Map<String, Object> attribute = new HashMap<>(); + attribute.put(TYPE, type.toString().toLowerCase(Locale.getDefault())); + attribute.put(VALUES, values); + attribute.put(LABEL, isLabel); + return attribute; + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/data/DescriptorException.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/data/DescriptorException.java b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/data/DescriptorException.java new file mode 100644 index 0000000..e7a10ff --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/data/DescriptorException.java @@ -0,0 +1,28 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.classifier.df.data; + +/** + * Exception thrown when parsing a descriptor + */ +@Deprecated +public class DescriptorException extends Exception { + public DescriptorException(String msg) { + super(msg); + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/data/DescriptorUtils.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/data/DescriptorUtils.java b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/data/DescriptorUtils.java new file mode 100644 index 0000000..aadedbd --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/data/DescriptorUtils.java @@ -0,0 +1,110 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.classifier.df.data; + +import com.google.common.base.Splitter; +import org.apache.mahout.classifier.df.data.Dataset.Attribute; + +import java.util.ArrayList; +import java.util.List; +import java.util.Locale; + +/** + * Contains various methods that deal with descriptor strings + */ +@Deprecated +public final class DescriptorUtils { + + private static final Splitter SPACE = Splitter.on(' ').omitEmptyStrings(); + + private DescriptorUtils() { } + + /** + * Parses a descriptor string and generates the corresponding array of Attributes + * + * @throws DescriptorException + * if a bad token is encountered + */ + public static Attribute[] parseDescriptor(CharSequence descriptor) throws DescriptorException { + List<Attribute> attributes = new ArrayList<>(); + for (String token : SPACE.split(descriptor)) { + token = token.toUpperCase(Locale.ENGLISH); + if ("I".equals(token)) { + attributes.add(Attribute.IGNORED); + } else if ("N".equals(token)) { + attributes.add(Attribute.NUMERICAL); + } else if ("C".equals(token)) { + attributes.add(Attribute.CATEGORICAL); + } else if ("L".equals(token)) { + attributes.add(Attribute.LABEL); + } else { + throw new DescriptorException("Bad Token : " + token); + } + } + return attributes.toArray(new Attribute[attributes.size()]); + } + + /** + * Generates a valid descriptor string from a user-friendly representation.<br> + * for example "3 N I N N 2 C L 5 I" generates "N N N I N N C C L I I I I I".<br> + * this useful when describing datasets with a large number of attributes + * @throws DescriptorException + */ + public static String generateDescriptor(CharSequence description) throws DescriptorException { + return generateDescriptor(SPACE.split(description)); + } + + /** + * Generates a valid descriptor string from a list of tokens + * @throws DescriptorException + */ + public static String generateDescriptor(Iterable<String> tokens) throws DescriptorException { + StringBuilder descriptor = new 
StringBuilder(); + + int multiplicator = 0; + + for (String token : tokens) { + try { + // try to parse an integer + int number = Integer.parseInt(token); + + if (number <= 0) { + throw new DescriptorException("Multiplicator (" + number + ") must be > 0"); + } + if (multiplicator > 0) { + throw new DescriptorException("A multiplicator cannot be followed by another multiplicator"); + } + + multiplicator = number; + } catch (NumberFormatException e) { + // token is not a number + if (multiplicator == 0) { + multiplicator = 1; + } + + for (int index = 0; index < multiplicator; index++) { + descriptor.append(token).append(' '); + } + + multiplicator = 0; + } + } + + return descriptor.toString().trim(); + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/data/Instance.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/data/Instance.java b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/data/Instance.java new file mode 100644 index 0000000..6a23cb8 --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/data/Instance.java @@ -0,0 +1,75 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.classifier.df.data; + +import org.apache.mahout.math.Vector; + +/** + * Represents one data instance. + */ +@Deprecated +public class Instance { + + /** attributes, except LABEL and IGNORED */ + private final Vector attrs; + + public Instance(Vector attrs) { + this.attrs = attrs; + } + + /** + * Return the attribute at the specified position + * + * @param index + * position of the attribute to retrieve + * @return value of the attribute + */ + public double get(int index) { + return attrs.getQuick(index); + } + + /** + * Set the value at the given index + * + * @param value + * a double value to set + */ + public void set(int index, double value) { + attrs.set(index, value); + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (!(obj instanceof Instance)) { + return false; + } + + Instance instance = (Instance) obj; + + return /*id == instance.id &&*/ attrs.equals(instance.attrs); + + } + + @Override + public int hashCode() { + return /*id +*/ attrs.hashCode(); + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/data/conditions/Condition.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/data/conditions/Condition.java b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/data/conditions/Condition.java new file mode 100644 index 0000000..c16ca3f --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/data/conditions/Condition.java @@ -0,0 +1,57 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.mahout.classifier.df.data.conditions;

import org.apache.mahout.classifier.df.data.Instance;

/**
 * A predicate over a single {@link Instance}; also a factory for the built-in conditions.
 */
@Deprecated
public abstract class Condition {

  /**
   * Evaluates this condition against an instance.
   *
   * @param instance the instance to test
   * @return true if the instance satisfies the condition
   */
  public abstract boolean isTrueFor(Instance instance);

  /**
   * @return a condition that holds when attribute {@code attr} is equal to {@code value}
   */
  public static Condition equals(int attr, double value) {
    Condition condition = new Equals(attr, value);
    return condition;
  }

  /**
   * @return a condition that holds when attribute {@code attr} is less than {@code value}
   */
  public static Condition lesser(int attr, double value) {
    Condition condition = new Lesser(attr, value);
    return condition;
  }

  /**
   * @return a condition that holds when attribute {@code attr} is greater than or equal to {@code value}
   */
  public static Condition greaterOrEquals(int attr, double value) {
    Condition condition = new GreaterOrEquals(attr, value);
    return condition;
  }
}
// http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/data/conditions/Equals.java
// ----------------------------------------------------------------------
diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/data/conditions/Equals.java b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/data/conditions/Equals.java new file mode 100644 index 0000000..c51082b --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/data/conditions/Equals.java @@ -0,0 +1,42 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.classifier.df.data.conditions; + +import org.apache.mahout.classifier.df.data.Instance; + +/** + * True if a given attribute has a given value + */ +@Deprecated +public class Equals extends Condition { + + private final int attr; + + private final double value; + + public Equals(int attr, double value) { + this.attr = attr; + this.value = value; + } + + @Override + public boolean isTrueFor(Instance instance) { + return instance.get(attr) == value; + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/data/conditions/GreaterOrEquals.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/data/conditions/GreaterOrEquals.java b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/data/conditions/GreaterOrEquals.java new file mode 100644 index 0000000..3e3d1a4 --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/data/conditions/GreaterOrEquals.java @@ -0,0 +1,42 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.classifier.df.data.conditions; + +import org.apache.mahout.classifier.df.data.Instance; + +/** + * True if a given attribute has a value "greater or equal" than a given value + */ +@Deprecated +public class GreaterOrEquals extends Condition { + + private final int attr; + + private final double value; + + public GreaterOrEquals(int attr, double value) { + this.attr = attr; + this.value = value; + } + + @Override + public boolean isTrueFor(Instance v) { + return v.get(attr) >= value; + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/data/conditions/Lesser.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/data/conditions/Lesser.java b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/data/conditions/Lesser.java new file mode 100644 index 0000000..577cb24 --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/data/conditions/Lesser.java @@ -0,0 +1,42 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.classifier.df.data.conditions; + +import org.apache.mahout.classifier.df.data.Instance; + +/** + * True if a given attribute has a value "lesser" than a given value + */ +@Deprecated +public class Lesser extends Condition { + + private final int attr; + + private final double value; + + public Lesser(int attr, double value) { + this.attr = attr; + this.value = value; + } + + @Override + public boolean isTrueFor(Instance instance) { + return instance.get(attr) < value; + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/Builder.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/Builder.java b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/Builder.java new file mode 100644 index 0000000..32d7b5c --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/Builder.java @@ -0,0 +1,333 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.classifier.df.mapreduce; + +import com.google.common.base.Preconditions; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.filecache.DistributedCache; +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.mapreduce.InputSplit; +import org.apache.hadoop.mapreduce.Job; +import org.apache.mahout.classifier.df.DecisionForest; +import org.apache.mahout.classifier.df.builder.TreeBuilder; +import org.apache.mahout.classifier.df.data.Dataset; +import org.apache.mahout.common.HadoopUtil; +import org.apache.mahout.common.StringUtils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Comparator; + +/** + * Base class for Mapred DecisionForest builders. Takes care of storing the parameters common to the mapred + * implementations.<br> + * The child classes must implement at least : + * <ul> + * <li>void configureJob(Job) : to further configure the job before its launch; and</li> + * <li>DecisionForest parseOutput(Job, PredictionCallback) : in order to convert the job outputs into a + * DecisionForest and its corresponding oob predictions</li> + * </ul> + * + */ +@Deprecated +public abstract class Builder { + + private static final Logger log = LoggerFactory.getLogger(Builder.class); + + private final TreeBuilder treeBuilder; + private final Path dataPath; + private final Path datasetPath; + private final Long seed; + private final Configuration conf; + private String outputDirName = "output"; + + protected Builder(TreeBuilder treeBuilder, Path dataPath, Path datasetPath, Long seed, Configuration conf) { + this.treeBuilder = treeBuilder; + this.dataPath = dataPath; + this.datasetPath = datasetPath; + this.seed = seed; + this.conf = new Configuration(conf); + } + + protected Path getDataPath() { + return dataPath; + } + + /** + * Return the value of "mapred.map.tasks". 
+ * + * @param conf + * configuration + * @return number of map tasks + */ + public static int getNumMaps(Configuration conf) { + return conf.getInt("mapred.map.tasks", -1); + } + + /** + * Used only for DEBUG purposes. if false, the mappers doesn't output anything, so the builder has nothing + * to process + * + * @param conf + * configuration + * @return true if the builder has to return output. false otherwise + */ + protected static boolean isOutput(Configuration conf) { + return conf.getBoolean("debug.mahout.rf.output", true); + } + + /** + * Returns the random seed + * + * @param conf + * configuration + * @return null if no seed is available + */ + public static Long getRandomSeed(Configuration conf) { + String seed = conf.get("mahout.rf.random.seed"); + if (seed == null) { + return null; + } + + return Long.valueOf(seed); + } + + /** + * Sets the random seed value + * + * @param conf + * configuration + * @param seed + * random seed + */ + private static void setRandomSeed(Configuration conf, long seed) { + conf.setLong("mahout.rf.random.seed", seed); + } + + public static TreeBuilder getTreeBuilder(Configuration conf) { + String string = conf.get("mahout.rf.treebuilder"); + if (string == null) { + return null; + } + + return StringUtils.fromString(string); + } + + private static void setTreeBuilder(Configuration conf, TreeBuilder treeBuilder) { + conf.set("mahout.rf.treebuilder", StringUtils.toString(treeBuilder)); + } + + /** + * Get the number of trees for the map-reduce job. 
+ * + * @param conf + * configuration + * @return number of trees to build + */ + public static int getNbTrees(Configuration conf) { + return conf.getInt("mahout.rf.nbtrees", -1); + } + + /** + * Set the number of trees to grow for the map-reduce job + * + * @param conf + * configuration + * @param nbTrees + * number of trees to build + * @throws IllegalArgumentException + * if (nbTrees <= 0) + */ + public static void setNbTrees(Configuration conf, int nbTrees) { + Preconditions.checkArgument(nbTrees > 0, "nbTrees should be greater than 0"); + + conf.setInt("mahout.rf.nbtrees", nbTrees); + } + + /** + * Sets the Output directory name, will be creating in the working directory + * + * @param name + * output dir. name + */ + public void setOutputDirName(String name) { + outputDirName = name; + } + + /** + * Output Directory name + * + * @param conf + * configuration + * @return output dir. path (%WORKING_DIRECTORY%/OUTPUT_DIR_NAME%) + * @throws IOException + * if we cannot get the default FileSystem + */ + protected Path getOutputPath(Configuration conf) throws IOException { + // the output directory is accessed only by this class, so use the default + // file system + FileSystem fs = FileSystem.get(conf); + return new Path(fs.getWorkingDirectory(), outputDirName); + } + + /** + * Helper method. Get a path from the DistributedCache + * + * @param conf + * configuration + * @param index + * index of the path in the DistributedCache files + * @return path from the DistributedCache + * @throws IOException + * if no path is found + */ + public static Path getDistributedCacheFile(Configuration conf, int index) throws IOException { + Path[] files = HadoopUtil.getCachedFiles(conf); + + if (files.length <= index) { + throw new IOException("path not found in the DistributedCache"); + } + + return files[index]; + } + + /** + * Helper method. 
Load a Dataset stored in the DistributedCache + * + * @param conf + * configuration + * @return loaded Dataset + * @throws IOException + * if we cannot retrieve the Dataset path from the DistributedCache, or the Dataset could not be + * loaded + */ + public static Dataset loadDataset(Configuration conf) throws IOException { + Path datasetPath = getDistributedCacheFile(conf, 0); + + return Dataset.load(conf, datasetPath); + } + + /** + * Used by the inheriting classes to configure the job + * + * + * @param job + * Hadoop's Job + * @throws IOException + * if anything goes wrong while configuring the job + */ + protected abstract void configureJob(Job job) throws IOException; + + /** + * Sequential implementation should override this method to simulate the job execution + * + * @param job + * Hadoop's job + * @return true is the job succeeded + */ + protected boolean runJob(Job job) throws ClassNotFoundException, IOException, InterruptedException { + return job.waitForCompletion(true); + } + + /** + * Parse the output files to extract the trees and pass the predictions to the callback + * + * @param job + * Hadoop's job + * @return Built DecisionForest + * @throws IOException + * if anything goes wrong while parsing the output + */ + protected abstract DecisionForest parseOutput(Job job) throws IOException; + + public DecisionForest build(int nbTrees) + throws IOException, ClassNotFoundException, InterruptedException { + // int numTrees = getNbTrees(conf); + + Path outputPath = getOutputPath(conf); + FileSystem fs = outputPath.getFileSystem(conf); + + // check the output + if (fs.exists(outputPath)) { + throw new IOException("Output path already exists : " + outputPath); + } + + if (seed != null) { + setRandomSeed(conf, seed); + } + setNbTrees(conf, nbTrees); + setTreeBuilder(conf, treeBuilder); + + // put the dataset into the DistributedCache + DistributedCache.addCacheFile(datasetPath.toUri(), conf); + + Job job = new Job(conf, "decision forest builder"); + + 
log.debug("Configuring the job..."); + configureJob(job); + + log.debug("Running the job..."); + if (!runJob(job)) { + log.error("Job failed!"); + return null; + } + + if (isOutput(conf)) { + log.debug("Parsing the output..."); + DecisionForest forest = parseOutput(job); + HadoopUtil.delete(conf, outputPath); + return forest; + } + + return null; + } + + /** + * sort the splits into order based on size, so that the biggest go first.<br> + * This is the same code used by Hadoop's JobClient. + * + * @param splits + * input splits + */ + public static void sortSplits(InputSplit[] splits) { + Arrays.sort(splits, new Comparator<InputSplit>() { + @Override + public int compare(InputSplit a, InputSplit b) { + try { + long left = a.getLength(); + long right = b.getLength(); + if (left == right) { + return 0; + } else if (left < right) { + return 1; + } else { + return -1; + } + } catch (IOException ie) { + throw new IllegalStateException("Problem getting input split size", ie); + } catch (InterruptedException ie) { + throw new IllegalStateException("Problem getting input split size", ie); + } + } + }); + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/Classifier.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/Classifier.java b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/Classifier.java new file mode 100644 index 0000000..1a35cfe --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/Classifier.java @@ -0,0 +1,238 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.classifier.df.mapreduce; + +import com.google.common.io.Closeables; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.filecache.DistributedCache; +import org.apache.hadoop.fs.FSDataOutputStream; +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.io.DoubleWritable; +import org.apache.hadoop.io.LongWritable; +import org.apache.hadoop.io.Text; +import org.apache.hadoop.mapreduce.Job; +import org.apache.hadoop.mapreduce.JobContext; +import org.apache.hadoop.mapreduce.Mapper; +import org.apache.hadoop.mapreduce.lib.input.FileInputFormat; +import org.apache.hadoop.mapreduce.lib.input.FileSplit; +import org.apache.hadoop.mapreduce.lib.input.TextInputFormat; +import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat; +import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat; +import org.apache.mahout.classifier.df.DFUtils; +import org.apache.mahout.classifier.df.DecisionForest; +import org.apache.mahout.classifier.df.data.DataConverter; +import org.apache.mahout.classifier.df.data.Dataset; +import org.apache.mahout.classifier.df.data.Instance; +import org.apache.mahout.common.HadoopUtil; +import org.apache.mahout.common.Pair; +import org.apache.mahout.common.RandomUtils; +import org.apache.mahout.common.iterator.sequencefile.SequenceFileIterable; +import org.slf4j.Logger; 
+import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Random; + +/** + * Mapreduce implementation that classifies the Input data using a previousely built decision forest + */ +@Deprecated +public class Classifier { + + private static final Logger log = LoggerFactory.getLogger(Classifier.class); + + private final Path forestPath; + private final Path inputPath; + private final Path datasetPath; + private final Configuration conf; + private final Path outputPath; // path that will containt the final output of the classifier + private final Path mappersOutputPath; // mappers will output here + private double[][] results; + + public double[][] getResults() { + return results; + } + + public Classifier(Path forestPath, + Path inputPath, + Path datasetPath, + Path outputPath, + Configuration conf) { + this.forestPath = forestPath; + this.inputPath = inputPath; + this.datasetPath = datasetPath; + this.outputPath = outputPath; + this.conf = conf; + + mappersOutputPath = new Path(outputPath, "mappers"); + } + + private void configureJob(Job job) throws IOException { + + job.setJarByClass(Classifier.class); + + FileInputFormat.setInputPaths(job, inputPath); + FileOutputFormat.setOutputPath(job, mappersOutputPath); + + job.setOutputKeyClass(DoubleWritable.class); + job.setOutputValueClass(Text.class); + + job.setMapperClass(CMapper.class); + job.setNumReduceTasks(0); // no reducers + + job.setInputFormatClass(CTextInputFormat.class); + job.setOutputFormatClass(SequenceFileOutputFormat.class); + + } + + public void run() throws IOException, ClassNotFoundException, InterruptedException { + FileSystem fs = FileSystem.get(conf); + + // check the output + if (fs.exists(outputPath)) { + throw new IOException("Output path already exists : " + outputPath); + } + + log.info("Adding the dataset to the DistributedCache"); + // put the dataset into the DistributedCache + 
DistributedCache.addCacheFile(datasetPath.toUri(), conf); + + log.info("Adding the decision forest to the DistributedCache"); + DistributedCache.addCacheFile(forestPath.toUri(), conf); + + Job job = new Job(conf, "decision forest classifier"); + + log.info("Configuring the job..."); + configureJob(job); + + log.info("Running the job..."); + if (!job.waitForCompletion(true)) { + throw new IllegalStateException("Job failed!"); + } + + parseOutput(job); + + HadoopUtil.delete(conf, mappersOutputPath); + } + + /** + * Extract the prediction for each mapper and write them in the corresponding output file. + * The name of the output file is based on the name of the corresponding input file. + * Will compute the ConfusionMatrix if necessary. + */ + private void parseOutput(JobContext job) throws IOException { + Configuration conf = job.getConfiguration(); + FileSystem fs = mappersOutputPath.getFileSystem(conf); + + Path[] outfiles = DFUtils.listOutputFiles(fs, mappersOutputPath); + + // read all the output + List<double[]> resList = new ArrayList<>(); + for (Path path : outfiles) { + FSDataOutputStream ofile = null; + try { + for (Pair<DoubleWritable,Text> record : new SequenceFileIterable<DoubleWritable,Text>(path, true, conf)) { + double key = record.getFirst().get(); + String value = record.getSecond().toString(); + if (ofile == null) { + // this is the first value, it contains the name of the input file + ofile = fs.create(new Path(outputPath, value).suffix(".out")); + } else { + // The key contains the correct label of the data. The value contains a prediction + ofile.writeChars(value); // write the prediction + ofile.writeChar('\n'); + + resList.add(new double[]{key, Double.valueOf(value)}); + } + } + } finally { + Closeables.close(ofile, false); + } + } + results = new double[resList.size()][2]; + resList.toArray(results); + } + + /** + * TextInputFormat that does not split the input files. This ensures that each input file is processed by one single + * mapper. 
+ */ + private static class CTextInputFormat extends TextInputFormat { + @Override + protected boolean isSplitable(JobContext jobContext, Path path) { + return false; + } + } + + public static class CMapper extends Mapper<LongWritable, Text, DoubleWritable, Text> { + + /** used to convert input values to data instances */ + private DataConverter converter; + private DecisionForest forest; + private final Random rng = RandomUtils.getRandom(); + private boolean first = true; + private final Text lvalue = new Text(); + private Dataset dataset; + private final DoubleWritable lkey = new DoubleWritable(); + + @Override + protected void setup(Context context) throws IOException, InterruptedException { + super.setup(context); //To change body of overridden methods use File | Settings | File Templates. + + Configuration conf = context.getConfiguration(); + + Path[] files = HadoopUtil.getCachedFiles(conf); + + if (files.length < 2) { + throw new IOException("not enough paths in the DistributedCache"); + } + dataset = Dataset.load(conf, files[0]); + converter = new DataConverter(dataset); + + forest = DecisionForest.load(conf, files[1]); + if (forest == null) { + throw new InterruptedException("DecisionForest not found!"); + } + } + + @Override + protected void map(LongWritable key, Text value, Context context) throws IOException, InterruptedException { + if (first) { + FileSplit split = (FileSplit) context.getInputSplit(); + Path path = split.getPath(); // current split path + lvalue.set(path.getName()); + lkey.set(key.get()); + context.write(lkey, lvalue); + + first = false; + } + + String line = value.toString(); + if (!line.isEmpty()) { + Instance instance = converter.convert(line); + double prediction = forest.classify(dataset, rng, instance); + lkey.set(dataset.getLabel(instance)); + lvalue.set(Double.toString(prediction)); + context.write(lkey, lvalue); + } + } + } +} 
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/MapredMapper.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/MapredMapper.java b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/MapredMapper.java new file mode 100644 index 0000000..4d0f3f1 --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/MapredMapper.java @@ -0,0 +1,75 @@
/**
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.mahout.classifier.df.mapreduce;

import com.google.common.base.Preconditions;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.mapreduce.Mapper;
import org.apache.mahout.classifier.df.builder.TreeBuilder;
import org.apache.mahout.classifier.df.data.Dataset;

import java.io.IOException;

/**
 * Base class for Mapred mappers. During {@link #setup} it reads the parameters shared by all the
 * mappers from the job configuration: the output flag, the tree builder and the dataset.
 */
@Deprecated
public class MapredMapper<KEYIN,VALUEIN,KEYOUT,VALUEOUT> extends Mapper<KEYIN,VALUEIN,KEYOUT,VALUEOUT> {

  /** dataset loaded from the job configuration */
  private Dataset dataset;

  /** builder used to grow the trees; checked to be non-null in configure() */
  private TreeBuilder treeBuilder;

  /** true when the mapper must not estimate nor output predictions */
  private boolean noOutput;

  /**
   * @return the tree builder read from the job configuration
   */
  protected TreeBuilder getTreeBuilder() {
    return treeBuilder;
  }

  /**
   * @return the dataset read from the job configuration
   */
  protected Dataset getDataset() {
    return dataset;
  }

  /**
   * @return whether the mapper does estimate and output predictions
   */
  protected boolean isOutput() {
    return !noOutput;
  }

  @Override
  protected void setup(Context context) throws IOException, InterruptedException {
    super.setup(context);

    Configuration jobConf = context.getConfiguration();

    // read in this exact order: output flag, tree builder, then dataset
    boolean skipOutput = !Builder.isOutput(jobConf);
    TreeBuilder builder = Builder.getTreeBuilder(jobConf);
    Dataset data = Builder.loadDataset(jobConf);

    configure(skipOutput, builder, data);
  }

  /**
   * Stores the mapper parameters. Kept separate from {@link #setup} so tests can call it directly.
   *
   * @param noOutput true if the mapper must not output predictions
   * @param treeBuilder builder used to grow the trees, must not be null
   * @param dataset dataset of the data
   * @throws IllegalArgumentException if treeBuilder is null
   */
  protected void configure(boolean noOutput, TreeBuilder treeBuilder, Dataset dataset) {
    Preconditions.checkArgument(treeBuilder != null, "TreeBuilder not found in the Job parameters");
    this.dataset = dataset;
    this.treeBuilder = treeBuilder;
    this.noOutput = noOutput;
  }
}
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/MapredOutput.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/MapredOutput.java b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/MapredOutput.java new file mode 100644 index 0000000..56cabb2 --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/MapredOutput.java @@ -0,0 +1,120 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.mahout.classifier.df.mapreduce;

import org.apache.hadoop.io.Writable;
import org.apache.mahout.classifier.df.DFUtils;
import org.apache.mahout.classifier.df.node.Node;

import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import java.util.Arrays;
import java.util.Objects;

/**
 * Used by various implementations to return the results of a build.<br>
 * Contains a grown tree and its oob predictions. Either of them may be {@code null}.
 */
@Deprecated
public class MapredOutput implements Writable, Cloneable {

  /** grown tree, may be null */
  private Node tree;

  /** oob predictions, may be null (e.g. when built with {@link #MapredOutput(Node)}) */
  private int[] predictions;

  public MapredOutput() {
  }

  public MapredOutput(Node tree, int[] predictions) {
    this.tree = tree;
    this.predictions = predictions;
  }

  public MapredOutput(Node tree) {
    this(tree, null);
  }

  public Node getTree() {
    return tree;
  }

  int[] getPredictions() {
    return predictions;
  }

  /**
   * Reads the fields written by {@link #write(DataOutput)}. A field that was not written is reset to
   * {@code null}, so an instance reused for several records never keeps a stale tree or predictions.
   */
  @Override
  public void readFields(DataInput in) throws IOException {
    tree = in.readBoolean() ? Node.read(in) : null;
    predictions = in.readBoolean() ? DFUtils.readIntArray(in) : null;
  }

  /**
   * Writes a presence flag for each field, followed by the field itself when it is not null.
   */
  @Override
  public void write(DataOutput out) throws IOException {
    out.writeBoolean(tree != null);
    if (tree != null) {
      tree.write(out);
    }

    out.writeBoolean(predictions != null);
    if (predictions != null) {
      DFUtils.writeArray(out, predictions);
    }
  }

  /**
   * Returns a copy that shares the tree but owns its own copy of the predictions array.
   */
  @Override
  public MapredOutput clone() {
    return new MapredOutput(tree, predictions == null ? null : predictions.clone());
  }

  @Override
  public boolean equals(Object obj) {
    if (this == obj) {
      return true;
    }
    if (!(obj instanceof MapredOutput)) {
      return false;
    }

    MapredOutput mo = (MapredOutput) obj;

    return Objects.equals(tree, mo.getTree()) && Arrays.equals(predictions, mo.getPredictions());
  }

  @Override
  public int hashCode() {
    int hashCode = tree == null ? 1 : tree.hashCode();
    // predictions is legitimately null (see MapredOutput(Node) and readFields), so guard the loop
    if (predictions != null) {
      for (int prediction : predictions) {
        hashCode = 31 * hashCode + prediction;
      }
    }
    return hashCode;
  }

  @Override
  public String toString() {
    return "{" + tree + " | " + Arrays.toString(predictions) + '}';
  }

}
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/inmem/InMemBuilder.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/inmem/InMemBuilder.java b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/inmem/InMemBuilder.java new file mode 100644 index 0000000..86d4404 --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/inmem/InMemBuilder.java @@ -0,0 +1,114 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
*/

package org.apache.mahout.classifier.df.mapreduce.inmem;

import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.filecache.DistributedCache;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
import org.apache.mahout.classifier.df.DFUtils;
import org.apache.mahout.classifier.df.DecisionForest;
import org.apache.mahout.classifier.df.builder.TreeBuilder;
import org.apache.mahout.classifier.df.mapreduce.Builder;
import org.apache.mahout.classifier.df.mapreduce.MapredOutput;
import org.apache.mahout.classifier.df.node.Node;
import org.apache.mahout.common.Pair;
import org.apache.mahout.common.iterator.sequencefile.SequenceFileIterable;

/**
 * MapReduce implementation where each mapper loads a full copy of the data in-memory. The trees of the
 * forest are split across all the mappers.
 */
@Deprecated
public class InMemBuilder extends Builder {

  public InMemBuilder(TreeBuilder treeBuilder, Path dataPath, Path datasetPath, Long seed, Configuration conf) {
    super(treeBuilder, dataPath, datasetPath, seed, conf);
  }

  /**
   * Convenience constructor: no explicit seed and a new default {@link Configuration}.
   */
  public InMemBuilder(TreeBuilder treeBuilder, Path dataPath, Path datasetPath) {
    this(treeBuilder, dataPath, datasetPath, null, new Configuration());
  }

  /**
   * Configures a map-only job. It ships the whole data to every mapper through the DistributedCache
   * and writes (IntWritable, MapredOutput) pairs to sequence files.
   */
  @Override
  protected void configureJob(Job job) throws IOException {
    Configuration jobConf = job.getConfiguration();

    job.setJarByClass(InMemBuilder.class);

    FileOutputFormat.setOutputPath(job, getOutputPath(jobConf));

    // each mapper loads a full copy of the data from the DistributedCache
    DistributedCache.addCacheFile(getDataPath().toUri(), jobConf);

    job.setMapperClass(InMemMapper.class);
    job.setNumReduceTasks(0); // map-only job, no reducers

    job.setInputFormatClass(InMemInputFormat.class);
    job.setOutputFormatClass(SequenceFileOutputFormat.class);
    job.setOutputKeyClass(IntWritable.class);
    job.setOutputValueClass(MapredOutput.class);
  }

  /**
   * Reads every (IntWritable, MapredOutput) record written by the mappers and builds the forest from them.
   */
  @Override
  protected DecisionForest parseOutput(Job job) throws IOException {
    Configuration jobConf = job.getConfiguration();
    Path outputPath = getOutputPath(jobConf);
    FileSystem fs = outputPath.getFileSystem(jobConf);

    // keyed by the int key emitted by the mappers (presumably the tree id; confirm in InMemMapper)
    Map<Integer,MapredOutput> output = new HashMap<>();
    for (Path path : DFUtils.listOutputFiles(fs, outputPath)) {
      for (Pair<IntWritable,MapredOutput> record : new SequenceFileIterable<IntWritable,MapredOutput>(path, jobConf)) {
        output.put(record.getFirst().get(), record.getSecond());
      }
    }

    return processOutput(output);
  }

  /**
   * Builds a forest from the trees of the collected outputs, in the map's iteration order.
   */
  private static DecisionForest processOutput(Map<Integer,MapredOutput> output) {
    List<Node> trees = new ArrayList<>(output.size());
    for (MapredOutput value : output.values()) {
      trees.add(value.getTree());
    }
    return new DecisionForest(trees);
  }
}
