http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/recommender/Rescorer.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/recommender/Rescorer.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/recommender/Rescorer.java new file mode 100644 index 0000000..1490761 --- /dev/null +++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/recommender/Rescorer.java @@ -0,0 +1,52 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.cf.taste.recommender; + +/** + * <p> + * A {@link Rescorer} simply assigns a new "score" to a thing like an ID of an item or user which a + * {@link Recommender} is considering returning as a top recommendation. It may be used to arbitrarily re-rank + * the results according to application-specific logic before returning recommendations. For example, an + * application may want to boost the score of items in a certain category just for one request.
+ * </p> + * + * <p> + * A {@link Rescorer} can also exclude a thing from consideration entirely by returning {@code true} from + * {@link #isFiltered(Object)}. + * </p> + */ +public interface Rescorer<T> { + + /** + * @param thing + * thing to rescore, such as an item or user ID + * @param originalScore + * original score + * @return modified score, or {@link Double#NaN} to indicate that the thing should be excluded entirely + */ + double rescore(T thing, double originalScore); + + /** + * Returns {@code true} to exclude the given thing. + * + * @param thing + * the thing to filter + * @return {@code true} to exclude, {@code false} otherwise + */ + boolean isFiltered(T thing); +}
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/recommender/UserBasedRecommender.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/recommender/UserBasedRecommender.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/recommender/UserBasedRecommender.java new file mode 100644 index 0000000..b48593a --- /dev/null +++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/recommender/UserBasedRecommender.java @@ -0,0 +1,54 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.mahout.cf.taste.recommender; + +import org.apache.mahout.cf.taste.common.TasteException; +import org.apache.mahout.common.LongPair; + +/** + * <p> + * Interface implemented by "user-based" recommenders.
+ * </p> + */ +public interface UserBasedRecommender extends Recommender { + + /** + * @param userID + * ID of user for which to find most similar other users + * @param howMany + * desired number of most similar users to find + * @return IDs of users most similar to the given user + * @throws TasteException + * if an error occurs while accessing the {@link org.apache.mahout.cf.taste.model.DataModel} + */ + long[] mostSimilarUserIDs(long userID, int howMany) throws TasteException; + + /** + * @param userID + * ID of user for which to find most similar other users + * @param howMany + * desired number of most similar users to find + * @param rescorer + * {@link Rescorer} which can adjust user-user similarity estimates used to determine most similar + * users + * @return IDs of users most similar to the given user + * @throws TasteException + * if an error occurs while accessing the {@link org.apache.mahout.cf.taste.model.DataModel} + */ + long[] mostSimilarUserIDs(long userID, int howMany, Rescorer<LongPair> rescorer) throws TasteException; + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/similarity/ItemSimilarity.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/similarity/ItemSimilarity.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/similarity/ItemSimilarity.java new file mode 100644 index 0000000..814610b --- /dev/null +++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/similarity/ItemSimilarity.java @@ -0,0 +1,64 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.cf.taste.similarity; + +import org.apache.mahout.cf.taste.common.Refreshable; +import org.apache.mahout.cf.taste.common.TasteException; + +/** + * <p> + * Implementations of this interface define a notion of similarity between two items. Implementations should + * return values in the range -1.0 to 1.0, with 1.0 representing perfect similarity. + * </p> + * + * @see UserSimilarity + */ +public interface ItemSimilarity extends Refreshable { + + /** + * <p> + * Returns the degree of similarity of two items, based on the preferences that users have expressed for + * the items.
+ * </p> + * + * @param itemID1 first item ID + * @param itemID2 second item ID + * @return similarity between the items, in [-1,1], or {@link Double#NaN} if similarity is unknown + * @throws org.apache.mahout.cf.taste.common.NoSuchItemException + * if either item is known to be non-existent in the data + * @throws TasteException if an error occurs while accessing the data + */ + double itemSimilarity(long itemID1, long itemID2) throws TasteException; + + /** + * <p>A bulk-get version of {@link #itemSimilarity(long, long)}.</p> + * + * @param itemID1 first item ID + * @param itemID2s second item IDs to compute similarity with + * @return similarities between itemID1 and each of the given itemID2s + * @throws org.apache.mahout.cf.taste.common.NoSuchItemException + * if any item is known to be non-existent in the data + * @throws TasteException if an error occurs while accessing the data + */ + double[] itemSimilarities(long itemID1, long[] itemID2s) throws TasteException; + + /** + * @return all IDs of items similar to {@code itemID}, in no particular order + */ + long[] allSimilarItemIDs(long itemID) throws TasteException; +} http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/similarity/PreferenceInferrer.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/similarity/PreferenceInferrer.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/similarity/PreferenceInferrer.java new file mode 100644 index 0000000..76bb328 --- /dev/null +++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/similarity/PreferenceInferrer.java @@ -0,0 +1,47 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.cf.taste.similarity; + +import org.apache.mahout.cf.taste.common.Refreshable; +import org.apache.mahout.cf.taste.common.TasteException; + +/** + * <p> + * Implementations of this interface compute an inferred preference for a user and an item that the user has + * not expressed any preference for. This might be an average of other preference values from that user, for + * example. This technique is sometimes called "default voting". + * </p> + */ +public interface PreferenceInferrer extends Refreshable { + + /** + * <p> + * Infers the given user's preference value for an item.
+ * </p> + * + * @param userID + * ID of user to infer preference for + * @param itemID + * item ID to infer preference for + * @return inferred preference value + * @throws TasteException + * if an error occurs while inferring + */ + float inferPreference(long userID, long itemID) throws TasteException; + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/similarity/UserSimilarity.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/similarity/UserSimilarity.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/similarity/UserSimilarity.java new file mode 100644 index 0000000..bd53c51 --- /dev/null +++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/similarity/UserSimilarity.java @@ -0,0 +1,58 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.cf.taste.similarity; + +import org.apache.mahout.cf.taste.common.Refreshable; +import org.apache.mahout.cf.taste.common.TasteException; + +/** + * <p> + * Implementations of this interface define a notion of similarity between two users.
Implementations should + * return values in the range -1.0 to 1.0, with 1.0 representing perfect similarity. + * </p> + * + * @see ItemSimilarity + */ +public interface UserSimilarity extends Refreshable { + + /** + * <p> + * Returns the degree of similarity of two users, based on their preferences. + * </p> + * + * @param userID1 first user ID + * @param userID2 second user ID + * @return similarity between the users, in [-1,1], or {@link Double#NaN} if similarity is unknown + * @throws org.apache.mahout.cf.taste.common.NoSuchUserException + * if either user is known to be non-existent in the data + * @throws TasteException if an error occurs while accessing the data + */ + double userSimilarity(long userID1, long userID2) throws TasteException; + + // Should we implement userSimilarities() like ItemSimilarity.itemSimilarities()? + + /** + * <p> + * Attaches a {@link PreferenceInferrer} to the {@link UserSimilarity} implementation. + * </p> + * + * @param inferrer {@link PreferenceInferrer} to attach + */ + void setPreferenceInferrer(PreferenceInferrer inferrer); + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/similarity/precompute/BatchItemSimilarities.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/similarity/precompute/BatchItemSimilarities.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/similarity/precompute/BatchItemSimilarities.java new file mode 100644 index 0000000..b934d0c --- /dev/null +++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/similarity/precompute/BatchItemSimilarities.java @@ -0,0 +1,56 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.cf.taste.similarity.precompute; + +import org.apache.mahout.cf.taste.recommender.ItemBasedRecommender; + +import java.io.IOException; + +public abstract class BatchItemSimilarities { + + private final ItemBasedRecommender recommender; + private final int similarItemsPerItem; + + /** + * @param recommender recommender to use + * @param similarItemsPerItem number of similar items to compute per item + */ + protected BatchItemSimilarities(ItemBasedRecommender recommender, int similarItemsPerItem) { + this.recommender = recommender; + this.similarItemsPerItem = similarItemsPerItem; + } + + protected ItemBasedRecommender getRecommender() { + return recommender; + } + + protected int getSimilarItemsPerItem() { + return similarItemsPerItem; + } + + /** + * @param degreeOfParallelism number of threads to use for the computation + * @param maxDurationInHours maximum duration of the computation, in hours + * @param writer {@link SimilarItemsWriter} used to persist the results + * @return the number of similarities precomputed + * @throws IOException if an I/O error occurs (presumably while persisting results through {@code writer}) + * @throws RuntimeException if the computation takes longer than maxDurationInHours + */ + public abstract int computeItemSimilarities(int degreeOfParallelism, int maxDurationInHours, + SimilarItemsWriter writer) throws IOException; +}
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/similarity/precompute/SimilarItem.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/similarity/precompute/SimilarItem.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/similarity/precompute/SimilarItem.java new file mode 100644 index 0000000..5d40051 --- /dev/null +++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/similarity/precompute/SimilarItem.java @@ -0,0 +1,56 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.mahout.cf.taste.similarity.precompute; + +import com.google.common.primitives.Doubles; + +import java.util.Comparator; + +/** + * Models the similarity towards another item, identified by that item's ID + */ +public class SimilarItem { + + public static final Comparator<SimilarItem> COMPARE_BY_SIMILARITY = new Comparator<SimilarItem>() { + @Override + public int compare(SimilarItem s1, SimilarItem s2) { + return Doubles.compare(s1.similarity, s2.similarity); + } + }; + + private long itemID; + private double similarity; + + public SimilarItem(long itemID, double similarity) { + set(itemID, similarity); + } + + public void set(long itemID, double similarity) { + this.itemID = itemID; + this.similarity = similarity; + } + + public long getItemID() { + return itemID; + } + + public double getSimilarity() { + return similarity; + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/similarity/precompute/SimilarItems.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/similarity/precompute/SimilarItems.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/similarity/precompute/SimilarItems.java new file mode 100644 index 0000000..057e996 --- /dev/null +++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/similarity/precompute/SimilarItems.java @@ -0,0 +1,84 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.cf.taste.similarity.precompute; + +import com.google.common.collect.UnmodifiableIterator; +import org.apache.mahout.cf.taste.recommender.RecommendedItem; + +import java.util.Iterator; +import java.util.List; +import java.util.NoSuchElementException; + +/** + * Compact, array-backed representation of all similar items for an item + */ +public class SimilarItems { + + private final long itemID; + private final long[] similarItemIDs; + private final double[] similarities; + + public SimilarItems(long itemID, List<RecommendedItem> similarItems) { + this.itemID = itemID; + + int numSimilarItems = similarItems.size(); + similarItemIDs = new long[numSimilarItems]; + similarities = new double[numSimilarItems]; + + for (int n = 0; n < numSimilarItems; n++) { + similarItemIDs[n] = similarItems.get(n).getItemID(); + similarities[n] = similarItems.get(n).getValue(); + } + } + + public long getItemID() { + return itemID; + } + + public int numSimilarItems() { + return similarItemIDs.length; + } + + public Iterable<SimilarItem> getSimilarItems() { + return new Iterable<SimilarItem>() { + @Override + public Iterator<SimilarItem> iterator() { + return new SimilarItemsIterator(); + } + }; + } + + private class SimilarItemsIterator extends UnmodifiableIterator<SimilarItem> { + + private int index = -1; + + @Override + public boolean hasNext() { + return index < (similarItemIDs.length - 1); + } + + @Override + public SimilarItem next() { + if (!hasNext()) { + throw new NoSuchElementException(); + } + index++; + return new
SimilarItem(similarItemIDs[index], similarities[index]); + } + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/similarity/precompute/SimilarItemsWriter.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/similarity/precompute/SimilarItemsWriter.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/similarity/precompute/SimilarItemsWriter.java new file mode 100644 index 0000000..35d6bfe --- /dev/null +++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/similarity/precompute/SimilarItemsWriter.java @@ -0,0 +1,33 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ +package org.apache.mahout.cf.taste.similarity.precompute; + +import java.io.Closeable; +import java.io.IOException; + +/** + * Persists the results of a batch item similarity computation + * conducted with a {@link BatchItemSimilarities} implementation. + */ +public interface SimilarItemsWriter extends Closeable { + + void open() throws IOException; + + void add(SimilarItems similarItems) throws IOException; + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/AbstractVectorClassifier.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/AbstractVectorClassifier.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/AbstractVectorClassifier.java new file mode 100644 index 0000000..efd233f --- /dev/null +++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/AbstractVectorClassifier.java @@ -0,0 +1,248 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ +package org.apache.mahout.classifier; + +import org.apache.mahout.math.DenseMatrix; +import org.apache.mahout.math.DenseVector; +import org.apache.mahout.math.Matrix; +import org.apache.mahout.math.Vector; + +import com.google.common.base.Preconditions; + +/** + * Defines the interface for classifiers that take a vector as input. This is + * implemented as an abstract class so that it can implement a number of handy + * convenience methods related to classification of vectors. + * + * <p> + * A classifier takes an input vector and calculates the scores (usually + * probabilities) that the input vector belongs to one of {@code n} + * categories. In {@code AbstractVectorClassifier} each category is denoted + * by an integer {@code c} between {@code 0} and {@code n-1} + * (inclusive). + * + * <p> + * New users should start by looking at {@link #classifyFull} (not {@link #classify}). + * + */ +public abstract class AbstractVectorClassifier { + + /** Minimum allowable log likelihood value. */ + public static final double MIN_LOG_LIKELIHOOD = -100.0; + + /** + * Returns the number of categories that a target variable can be assigned to. + * A vector classifier will encode its output as an integer from + * {@code 0} to {@code numCategories()-1} (inclusive). + * + * @return The number of categories. + */ + public abstract int numCategories(); + + /** + * Compute and return a vector containing {@code n-1} scores, where + * {@code n} is equal to {@code numCategories()}, given an input + * vector {@code instance}. Higher scores indicate that the input vector + * is more likely to belong to that category. The categories are denoted by + * the integers {@code 0} through {@code n-1} (inclusive), and the + * scores in the returned vector correspond to categories 1 through + * {@code n-1} (leaving out category 0). It is assumed that the score for + * category 0 is one minus the sum of the scores in the returned vector.
+ * + * @param instance A feature vector to be classified. + * @return A vector of probabilities in 1 of {@code n-1} encoding. + */ + public abstract Vector classify(Vector instance); + + /** + * Compute and return a vector of scores before applying the inverse link + * function. For logistic regression and other generalized linear models, this + * is just the linear part of the classification. + * + * <p> + * The implementation of this method provided by {@code AbstractVectorClassifier} throws an + * {@link UnsupportedOperationException}. Your subclass must explicitly override this method to support + * this operation. + * + * @param features A feature vector to be classified. + * @return A vector of scores. If transformed by the link function, these will become probabilities. + */ + public Vector classifyNoLink(Vector features) { + throw new UnsupportedOperationException(this.getClass().getName() + + " doesn't support classification without a link"); + } + + /** + * Classifies a vector in the special case of a binary classifier where + * {@link #classify(Vector)} would return a vector with only one element. As + * such, using this method can avoid the allocation of a vector. + * + * @param instance The feature vector to be classified. + * @return The score for category 1. + * + * @see #classify(Vector) + */ + public abstract double classifyScalar(Vector instance); + + /** + * Computes and returns a vector containing {@code n} scores, where + * {@code n} is {@code numCategories()}, given an input vector + * {@code instance}. Higher scores indicate that the input vector is more + * likely to belong to the corresponding category. The categories are denoted + * by the integers {@code 0} through {@code n-1} (inclusive). + * + * <p> + * Using this method it is possible to classify an input vector, for example, + * by selecting the category with the largest score.
If + * {@code classifier} is an instance of + * {@code AbstractVectorClassifier} and {@code input} is a + * {@code Vector} of features describing an element to be classified, + * then the following code could be used to classify {@code input}.<br> + * {@code + * Vector scores = classifier.classifyFull(input);<br> + * int assignedCategory = scores.maxValueIndex();<br> + * } Here {@code assignedCategory} is the index of the category + * with the maximum score. + * + * <p> + * If an {@code n-1} encoding is acceptable, and allocation performance + * is an issue, then the {@link #classify(Vector)} method is probably better + * to use. + * + * @see #classify(Vector) + * @see #classifyFull(Vector, Vector) + * + * @param instance A vector of features to be classified. + * @return A vector of probabilities, one for each category. + */ + public Vector classifyFull(Vector instance) { + return classifyFull(new DenseVector(numCategories()), instance); + } + + /** + * Computes and returns a vector containing {@code n} scores, where + * {@code n} is {@code numCategories()}, given an input vector + * {@code instance}. Higher scores indicate that the input vector is more + * likely to belong to the corresponding category. The categories are denoted + * by the integers {@code 0} through {@code n-1} (inclusive). The + * main difference between this method and {@link #classifyFull(Vector)} is + * that this method allows a user to provide a previously allocated + * {@code Vector r} to store the returned scores. + * + * <p> + * Using this method it is possible to classify an input vector, for example, + * by selecting the category with the largest score.
If + * {@code classifier} is an instance of + * {@code AbstractVectorClassifier}, {@code result} is a non-null + * {@code Vector}, and {@code input} is a {@code Vector} of + * features describing an element to be classified, then the following code + * could be used to classify {@code input}.<br> + * {@code + * Vector scores = classifier.classifyFull(result, input); // Notice that scores == result<br> + * int assignedCategory = scores.maxValueIndex();<br> + * } Here {@code assignedCategory} is the index of the category + * with the maximum score. + * + * @param r Where to put the results. Its first {@code numCategories()} entries are overwritten, so it need not be cleared before reuse. + * @param instance A vector of features to be classified. + * @return {@code r}, holding one score/probability for each category. + */ + public Vector classifyFull(Vector r, Vector instance) { + double otherCategoriesSum = r.viewPart(1, numCategories() - 1).assign(classify(instance)).zSum(); + r.setQuick(0, 1.0 - otherCategoriesSum); + return r; + } + + + /** + * Returns n-1 probabilities, one for each of categories 1 through + * {@code n-1}, for each row of a matrix, where {@code n} is equal + * to {@code numCategories()}. The probability of the missing 0-th + * category is 1 - rowSum(this result). + * + * @param data The matrix whose rows are the input vectors to classify + * @return A matrix of scores, one row per row of the input matrix, one column for each category except category 0. + */ + public Matrix classify(Matrix data) { + Matrix r = new DenseMatrix(data.numRows(), numCategories() - 1); + for (int row = 0; row < data.numRows(); row++) { + r.assignRow(row, classify(data.viewRow(row))); + } + return r; + } + + /** + * Returns a matrix where the rows of the matrix each contain {@code n} probabilities, one for each category. + * + * @param data The matrix whose rows are the input vectors to classify + * @return A matrix of scores, one row per row of the input matrix, one column for each category.
+ */ + public Matrix classifyFull(Matrix data) { + Matrix r = new DenseMatrix(data.numRows(), numCategories()); + for (int row = 0; row < data.numRows(); row++) { + classifyFull(r.viewRow(row), data.viewRow(row)); + } + return r; + } + + /** + * Returns a vector of probabilities of category 1, one for each row + * of a matrix. This only makes sense if there are exactly two categories, but + * calling this method in that case can save a number of vector allocations. + * + * @param data The matrix whose rows are vectors to classify + * @return A vector of scores, with one value per row of the input matrix. + */ + public Vector classifyScalar(Matrix data) { + Preconditions.checkArgument(numCategories() == 2, "Can only call classifyScalar with two categories"); + + Vector r = new DenseVector(data.numRows()); + for (int row = 0; row < data.numRows(); row++) { + r.set(row, classifyScalar(data.viewRow(row))); + } + return r; + } + + /** + * Returns a measure of how good the classification for a particular example + * actually is. + * + * @param actual The correct category for the example. + * @param data The vector to be classified. + * @return The log likelihood of the correct answer as estimated by the current model. This will always be {@code <= 0} + * and larger (closer to 0) indicates better accuracy. In order to simplify code that maintains running averages, + * we bound this value at {@link #MIN_LOG_LIKELIHOOD} (-100).
+ */ + public double logLikelihood(int actual, Vector data) { + if (numCategories() == 2) { + double p = classifyScalar(data); + if (actual > 0) { + return Math.max(MIN_LOG_LIKELIHOOD, Math.log(p)); + } else { + return Math.max(MIN_LOG_LIKELIHOOD, Math.log1p(-p)); + } + } else { + Vector p = classify(data); + if (actual > 0) { + return Math.max(MIN_LOG_LIKELIHOOD, Math.log(p.get(actual - 1))); + } else { + return Math.max(MIN_LOG_LIKELIHOOD, Math.log1p(-p.zSum())); + } + } + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/ClassifierResult.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/ClassifierResult.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/ClassifierResult.java new file mode 100644 index 0000000..29eaa0d --- /dev/null +++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/ClassifierResult.java @@ -0,0 +1,74 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.classifier; + +/** + * Result of a document classification.
The label and the associated score (usually probabilty) + */ +public class ClassifierResult { + + private String label; + private double score; + private double logLikelihood = Double.MAX_VALUE; + + public ClassifierResult() { } + + public ClassifierResult(String label, double score) { + this.label = label; + this.score = score; + } + + public ClassifierResult(String label) { + this.label = label; + } + + public ClassifierResult(String label, double score, double logLikelihood) { + this.label = label; + this.score = score; + this.logLikelihood = logLikelihood; + } + + public double getLogLikelihood() { + return logLikelihood; + } + + public void setLogLikelihood(double logLikelihood) { + this.logLikelihood = logLikelihood; + } + + public String getLabel() { + return label; + } + + public double getScore() { + return score; + } + + public void setLabel(String label) { + this.label = label; + } + + public void setScore(double score) { + this.score = score; + } + + @Override + public String toString() { + return "ClassifierResult{" + "category='" + label + '\'' + ", score=" + score + '}'; + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/ConfusionMatrix.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/ConfusionMatrix.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/ConfusionMatrix.java new file mode 100644 index 0000000..73ba521 --- /dev/null +++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/ConfusionMatrix.java @@ -0,0 +1,444 @@ +/** + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * Licensed to the Apache Software Foundation (ASF) under one or more + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.classifier; + +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.LinkedHashMap; +import java.util.Map; + +import com.google.common.base.Preconditions; +import org.apache.commons.lang3.StringUtils; +import org.apache.commons.math3.stat.descriptive.moment.Mean; +import org.apache.mahout.cf.taste.impl.common.FullRunningAverageAndStdDev; +import org.apache.mahout.cf.taste.impl.common.RunningAverageAndStdDev; +import org.apache.mahout.math.DenseMatrix; +import org.apache.mahout.math.Matrix; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * The ConfusionMatrix Class stores the result of Classification of a Test Dataset. + * + * The fact of whether there is a default is not stored. A row of zeros is the only indicator that there is no default. 
+ *
+ * See http://en.wikipedia.org/wiki/Confusion_matrix for background
+ */
+public class ConfusionMatrix {
+  private static final Logger LOG = LoggerFactory.getLogger(ConfusionMatrix.class);
+  private final Map<String,Integer> labelMap = new LinkedHashMap<>();
+  private final int[][] confusionMatrix;
+  private String defaultLabel = "unknown";
+
+  /** Creates an empty matrix for {@code labels}, with a last row and column reserved for {@code defaultLabel}. */
+  public ConfusionMatrix(Collection<String> labels, String defaultLabel) {
+    confusionMatrix = new int[labels.size() + 1][labels.size() + 1];
+    this.defaultLabel = defaultLabel;
+    int i = 0;
+    for (String label : labels) {
+      labelMap.put(label, i++);
+    }
+    labelMap.put(defaultLabel, i);
+  }
+
+  /** Creates a square matrix sized from {@code m} and copies its counts and label bindings (see {@link #setMatrix}). */
+  public ConfusionMatrix(Matrix m) {
+    confusionMatrix = new int[m.numRows()][m.numRows()];
+    setMatrix(m);
+  }
+
+  /** Returns the live backing array (not a copy), indexed [correct label][classified label]. */
+  public int[][] getConfusionMatrix() {
+    return confusionMatrix;
+  }
+
+  public Collection<String> getLabels() {
+    return Collections.unmodifiableCollection(labelMap.keySet());
+  }
+
+  private int numLabels() {
+    return labelMap.size();
+  }
+
+  /** Percentage of instances truly labelled {@code label} that were classified as {@code label} (NaN if none). */
+  public double getAccuracy(String label) {
+    int labelId = labelMap.get(label);
+    int correct = confusionMatrix[labelId][labelId];
+    int labelTotal = 0;
+    for (int i = 0; i < numLabels(); i++) {
+      labelTotal += confusionMatrix[labelId][i];
+    }
+    return 100.0 * correct / labelTotal;
+  }
+
+  /** "Producer accuracy": percentage of all counted instances that lie on the diagonal (NaN if empty). */
+  public double getAccuracy() {
+    int total = 0;
+    int correct = 0;
+    for (int i = 0; i < numLabels(); i++) {
+      correct += confusionMatrix[i][i];
+      for (int j = 0; j < numLabels(); j++) {
+        total += confusionMatrix[i][j];
+      }
+    }
+    return 100.0 * correct / total;
+  }
+
+  /** Sum of true positives and false negatives */
+  private int getActualNumberOfTestExamplesForClass(String label) {
+    return getTotal(label);
+  }
+
+  /** Fraction of instances classified as {@code label} that truly have it; 0 if nothing was classified as it. */
+  public double getPrecision(String label) {
+    int labelId = labelMap.get(label);
+    int truePositives = confusionMatrix[labelId][labelId];
+    int falsePositives = 0;
+    for (int i = 0; i < numLabels(); i++) {
+      if (i == labelId) {
+        continue;
+      }
+      falsePositives += confusionMatrix[i][labelId];
+    }
+
+    if (truePositives + falsePositives == 0) {
+      return 0;
+    }
+
+    return ((double) truePositives) / (truePositives + falsePositives);
+  }
+
+  /** Mean of per-label precision, weighted by each label's actual instance count. */
+  public double getWeightedPrecision() {
+    double[] precisions = new double[numLabels()];
+    double[] weights = new double[numLabels()];
+
+    int index = 0;
+    for (String label : labelMap.keySet()) {
+      precisions[index] = getPrecision(label);
+      weights[index] = getActualNumberOfTestExamplesForClass(label);
+      index++;
+    }
+    return new Mean().evaluate(precisions, weights);
+  }
+
+  /** Fraction of instances truly labelled {@code label} that were classified as it; 0 if there are none. */
+  public double getRecall(String label) {
+    int labelId = labelMap.get(label);
+    int truePositives = confusionMatrix[labelId][labelId];
+    int falseNegatives = 0;
+    for (int i = 0; i < numLabels(); i++) {
+      if (i == labelId) {
+        continue;
+      }
+      falseNegatives += confusionMatrix[labelId][i];
+    }
+    if (truePositives + falseNegatives == 0) {
+      return 0;
+    }
+    return ((double) truePositives) / (truePositives + falseNegatives);
+  }
+
+  /** Mean of per-label recall, weighted by each label's actual instance count. */
+  public double getWeightedRecall() {
+    double[] recalls = new double[numLabels()];
+    double[] weights = new double[numLabels()];
+
+    int index = 0;
+    for (String label : labelMap.keySet()) {
+      recalls[index] = getRecall(label);
+      weights[index] = getActualNumberOfTestExamplesForClass(label);
+      index++;
+    }
+    return new Mean().evaluate(recalls, weights);
+  }
+
+  /** Harmonic mean of precision and recall for {@code label}; 0 if both are 0. */
+  public double getF1score(String label) {
+    double precision = getPrecision(label);
+    double recall = getRecall(label);
+    if (precision + recall == 0) {
+      return 0;
+    }
+    return 2 * precision * recall / (precision + recall);
+  }
+
+  /** Mean of per-label F1 score, weighted by each label's actual instance count. */
+  public double getWeightedF1score() {
+    double[] f1Scores = new double[numLabels()];
+    double[] weights = new double[numLabels()];
+
+    int index = 0;
+    for (String label : labelMap.keySet()) {
+      f1Scores[index] = getF1score(label);
+      weights[index] = getActualNumberOfTestExamplesForClass(label);
+      index++;
+    }
+    return new Mean().evaluate(f1Scores, weights);
+  }
+
+  /** "User accuracy": mean of {@link #getAccuracy(String)} over every label except the default label. */
+  public double getReliability() {
+    int count = 0;
+    double accuracy = 0;
+    for (String label: labelMap.keySet()) {
+      if (!label.equals(defaultLabel)) {
+        accuracy += getAccuracy(label);
+        count++;
+      }
+    }
+    return accuracy / count;
+  }
+
+  /**
+   * Accuracy vs. randomly classifying all samples.
+   * kappa() = (totalAccuracy() - randomAccuracy()) / (1 - randomAccuracy())
+   * Cohen, Jacob. 1960. A coefficient of agreement for nominal scales.
+   * Educational And Psychological Measurement 20:37-46.
+   *
+   * Formula and variable names from:
+   * http://www.yale.edu/ceo/OEFS/Accuracy.pdf
+   *
+   * @return Cohen's kappa for the current counts (NaN when the matrix is empty)
+   */
+  public double getKappa() {
+    double a = 0.0;
+    double b = 0.0;
+    double n = 0.0; // total count, summed from the cells so it is right however they were filled
+    for (int i = 0; i < confusionMatrix.length; i++) {
+      a += confusionMatrix[i][i];
+      double br = 0;
+      for (int j = 0; j < confusionMatrix.length; j++) {
+        br += confusionMatrix[i][j];
+      }
+      n += br;
+      double bc = 0;
+      for (int[] vec : confusionMatrix) {
+        bc += vec[i];
+      }
+      b += br * bc;
+    }
+    return (n * a - b) / (n * n - b);
+  }
+
+  /**
+   * Mean and standard deviation of each row's diagonal fraction (per-label recall); not a standard score.
+   * Rows with no instances, such as an unused default label, contribute 0 to these statistics.
+   * @return running average and standard deviation accumulated over all rows
+   */
+  public RunningAverageAndStdDev getNormalizedStats() {
+    RunningAverageAndStdDev summer = new FullRunningAverageAndStdDev();
+    for (int d = 0; d < confusionMatrix.length; d++) {
+      double total = 0;
+      for (int j = 0; j < confusionMatrix.length; j++) {
+        total += confusionMatrix[d][j];
+      }
+      summer.addDatum(confusionMatrix[d][d] / (total + 0.000001));
+    }
+
+    return summer;
+  }
+
+  public int getCorrect(String label) {
+    int labelId = labelMap.get(label);
+    return confusionMatrix[labelId][labelId];
+  }
+
+  /** Number of instances whose correct label is {@code label} (the sum of that row). */
+  public int getTotal(String label) {
+    int labelId = labelMap.get(label);
+    int labelTotal = 0;
+    for (int i = 0; i < labelMap.size(); i++) {
+      labelTotal += confusionMatrix[labelId][i];
+    }
+    return labelTotal;
+  }
+
+  /** Counts one instance of {@code correctLabel} classified as {@code classifiedResult.getLabel()}. */
+  public void addInstance(String correctLabel, ClassifierResult classifiedResult) {
+    incrementCount(correctLabel, classifiedResult.getLabel());
+  }
+
+  /** Counts one instance of {@code correctLabel} classified as {@code classifiedLabel}. */
+  public void addInstance(String correctLabel, String classifiedLabel) {
+    incrementCount(correctLabel, classifiedLabel);
+  }
+
+  /** Count for (correct, classified); 0 with a warning if {@code correctLabel} is unknown, IAE if the other is. */
+  public int getCount(String correctLabel, String classifiedLabel) {
+    if (!labelMap.containsKey(correctLabel)) {
+      LOG.warn("Label {} did not appear in the training examples", correctLabel);
+      return 0;
+    }
+    Preconditions.checkArgument(labelMap.containsKey(classifiedLabel), "Label not found: " + classifiedLabel);
+    int correctId = labelMap.get(correctLabel);
+    int classifiedId = labelMap.get(classifiedLabel);
+    return confusionMatrix[correctId][classifiedId];
+  }
+
+  /** Sets the cell to {@code count}; ignored, with a warning, if {@code correctLabel} is unknown. */
+  public void putCount(String correctLabel, String classifiedLabel, int count) {
+    if (!labelMap.containsKey(correctLabel)) {
+      LOG.warn("Label {} did not appear in the training examples", correctLabel);
+      return;
+    }
+    Preconditions.checkArgument(labelMap.containsKey(classifiedLabel), "Label not found: " + classifiedLabel);
+    int correctId = labelMap.get(correctLabel);
+    int classifiedId = labelMap.get(classifiedLabel);
+    confusionMatrix[correctId][classifiedId] = count;
+  }
+
+  public String getDefaultLabel() {
+    return defaultLabel;
+  }
+
+  public void incrementCount(String correctLabel, String classifiedLabel, int count) {
+    putCount(correctLabel, classifiedLabel, count + getCount(correctLabel, classifiedLabel));
+  }
+
+  public void incrementCount(String correctLabel, String classifiedLabel) {
+    incrementCount(correctLabel, classifiedLabel, 1);
+  }
+
+  /** Adds each cell of {@code b} to the cell with the same labels in this matrix and returns {@code this}. */
+  public ConfusionMatrix merge(ConfusionMatrix b) {
+    Preconditions.checkArgument(labelMap.size() == b.getLabels().size(), "The label sizes do not match");
+    for (String correctLabel : this.labelMap.keySet()) {
+      for (String classifiedLabel : this.labelMap.keySet()) {
+        incrementCount(correctLabel, classifiedLabel, b.getCount(correctLabel, classifiedLabel));
+      }
+    }
+    return this;
+  }
+
+  /** Returns the counts as a new Matrix whose row and column label bindings map each label to its index. */
+  public Matrix getMatrix() {
+    int length = confusionMatrix.length;
+    Matrix m = new DenseMatrix(length, length);
+    for (int r = 0; r < length; r++) {
+      for (int c = 0; c < length; c++) {
+        m.set(r, c, confusionMatrix[r][c]);
+      }
+    }
+    Map<String,Integer> labels = new HashMap<>(labelMap);
+    m.setRowLabelBindings(labels);
+    m.setColumnLabelBindings(labels);
+    return m;
+  }
+
+  /** Copies {@code m}'s counts (rounded) and adopts its row (else column) label bindings, if it has any. */
+  public void setMatrix(Matrix m) {
+    int length = confusionMatrix.length;
+    if (m.numRows() != m.numCols()) {
+      throw new IllegalArgumentException(
+          "ConfusionMatrix: matrix(" + m.numRows() + ',' + m.numCols() + ") must be square");
+    }
+    for (int r = 0; r < length; r++) {
+      for (int c = 0; c < length; c++) {
+        confusionMatrix[r][c] = (int) Math.round(m.get(r, c));
+      }
+    }
+    Map<String,Integer> labels = m.getRowLabelBindings();
+    if (labels == null) {
+      labels = m.getColumnLabelBindings();
+    }
+    if (labels != null) {
+      String[] sorted = sortLabels(labels);
+      verifyLabels(length, sorted);
+      labelMap.clear();
+      for (int i = 0; i < length; i++) {
+        labelMap.put(sorted[i], i);
+      }
+    }
+  }
+
+  private static String[] sortLabels(Map<String,Integer> labels) {
+    String[] sorted = new String[labels.size()];
+    for (Map.Entry<String,Integer> entry : labels.entrySet()) {
+      sorted[entry.getValue()] = entry.getKey();
+    }
+    return sorted;
+  }
+
+  private static void verifyLabels(int length, String[] sorted) {
+    Preconditions.checkArgument(sorted.length == length, "One label, one row");
+    for (int i = 0; i < length; i++) {
+      Preconditions.checkArgument(sorted[i] != null, "One label, one row");
+    }
+  }
+
+  /**
+   * This is overloaded. toString() is not a formatted report you print for a manager :)
+   * Assume that if there are no default assignments, the default feature was not used
+   */
+  @Override
+  public String toString() {
+    StringBuilder returnString = new StringBuilder(200);
+    returnString.append("=======================================================").append('\n');
+    returnString.append("Confusion Matrix\n");
+    returnString.append("-------------------------------------------------------").append('\n');
+
+    int unclassified = getTotal(defaultLabel);
+    for (Map.Entry<String,Integer> entry : this.labelMap.entrySet()) {
+      if (entry.getKey().equals(defaultLabel) && unclassified == 0) {
+        continue;
+      }
+
+      returnString.append(StringUtils.rightPad(getSmallLabel(entry.getValue()), 5)).append('\t');
+    }
+
+    returnString.append("<--Classified as").append('\n');
+    for (Map.Entry<String,Integer> entry : this.labelMap.entrySet()) {
+      if (entry.getKey().equals(defaultLabel) && unclassified == 0) {
+        continue;
+      }
+      String correctLabel = entry.getKey();
+      int labelTotal = 0;
+      for (String classifiedLabel : this.labelMap.keySet()) {
+        if (classifiedLabel.equals(defaultLabel) && unclassified == 0) {
+          continue;
+        }
+        returnString.append(
+            StringUtils.rightPad(Integer.toString(getCount(correctLabel, classifiedLabel)), 5)).append('\t');
+        labelTotal += getCount(correctLabel, classifiedLabel);
+      }
+      returnString.append(" | ").append(StringUtils.rightPad(String.valueOf(labelTotal), 6)).append('\t')
+          .append(StringUtils.rightPad(getSmallLabel(entry.getValue()), 5))
+          .append(" = ").append(correctLabel).append('\n');
+    }
+    if (unclassified > 0) {
+      returnString.append("Default Category: ").append(defaultLabel).append(": ").append(unclassified).append('\n');
+    }
+    returnString.append('\n');
+    return returnString.toString();
+  }
+
+  /** Short column tag: {@code i} written in base 26 with digits a-z (0 -> "a", 25 -> "z", 26 -> "ba"). */
+  static String getSmallLabel(int i) {
+    int val = i;
+    StringBuilder returnString = new StringBuilder();
+    do {
+      int n = val % 26;
+      returnString.insert(0, (char) ('a' + n));
+      val /= 26;
+    } while (val > 0);
+    return returnString.toString();
+  }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/OnlineLearner.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/OnlineLearner.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/OnlineLearner.java
new file mode 100644
index 0000000..af1d5e7
--- /dev/null
+++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/OnlineLearner.java
@@ -0,0 +1,96 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier;
+
+import org.apache.mahout.math.Vector;
+
+import java.io.Closeable;
+
+/**
+ * The simplest interface for online learning algorithms.
+ */
+public interface OnlineLearner extends Closeable {
+  /**
+   * Updates the model using a particular target variable value and a feature vector.
+   * <p>
+   * There may be an assumption that if multiple passes through the training data are necessary, then
+   * the training examples will be presented in the same order. This is because the order of
+   * training examples may be used to assign records to different data splits for evaluation by
+   * cross-validation. Without the order invariance, records might be assigned to training and test
+   * splits and error estimates could be seriously affected.
+   * <p>
+   * If re-ordering is necessary, then the alternative API, which allows a tracking key to be
+   * added to the training example, can be used.
+   *
+   * @param actual The value of the target variable. This value should be in the half-open
+   *        interval [0..n) where n is the number of target categories.
+   * @param instance The feature vector for this example.
+   */
+  void train(int actual, Vector instance);
+
+  /**
+   * Updates the model using a particular target variable value and a feature vector.
+   * <p>
+   * There may be an assumption that if multiple passes through the training data are necessary that
+   * the tracking key for a record will be the same for each pass and that there will be a
+   * relatively large number of distinct tracking keys and that the low-order bits of the tracking
+   * keys will not correlate with any of the input variables. This tracking key is used to assign
+   * training examples to different test/training splits.
+   * <p>
+   * Examples of useful tracking keys include id-numbers for the training records derived from
+   * a database id for the base table from which the record is derived, or the offset of
+   * the original data record in a data file.
+   *
+   * @param trackingKey The tracking key for this training example.
+   * @param groupKey An optional value that allows examples to be grouped in the computation of
+   *        the update to the model.
+   * @param actual The value of the target variable. This value should be in the half-open
+   *        interval [0..n) where n is the number of target categories.
+   * @param instance The feature vector for this example.
+   */
+  void train(long trackingKey, String groupKey, int actual, Vector instance);
+
+  /**
+   * Updates the model using a particular target variable value and a feature vector.
+   * <p>
+   * There may be an assumption that if multiple passes through the training data are necessary that
+   * the tracking key for a record will be the same for each pass and that there will be a
+   * relatively large number of distinct tracking keys and that the low-order bits of the tracking
+   * keys will not correlate with any of the input variables. This tracking key is used to assign
+   * training examples to different test/training splits.
+   * <p>
+   * Examples of useful tracking keys include id-numbers for the training records derived from
+   * a database id for the base table from which the record is derived, or the offset of
+   * the original data record in a data file.
+   *
+   * @param trackingKey The tracking key for this training example.
+   * @param actual The value of the target variable. This value should be in the half-open
+   *        interval [0..n) where n is the number of target categories.
+   * @param instance The feature vector for this example.
+   */
+  void train(long trackingKey, int actual, Vector instance);
+
+  /**
+   * Prepares the classifier for classification and deallocates any temporary data structures.
+   *
+   * An online classifier should be able to accept more training after being closed, but
+   * closing the classifier may make classification more efficient.
+   */
+  @Override
+  void close();
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/RegressionResultAnalyzer.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/RegressionResultAnalyzer.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/RegressionResultAnalyzer.java
new file mode 100644
index 0000000..35c11ee
--- /dev/null
+++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/RegressionResultAnalyzer.java
@@ -0,0 +1,144 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier;
+
+import java.text.DecimalFormat;
+import java.text.DecimalFormatSymbols;
+import java.text.NumberFormat;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Locale;
+
+import org.apache.commons.lang3.StringUtils;
+
+/**
+ * RegressionResultAnalyzer captures regression statistics and displays them in a tabular manner
+ */
+public class RegressionResultAnalyzer {
+
+  private static class Result {
+    private final double actual;
+    private final double result;
+    Result(double actual, double result) {
+      this.actual = actual;
+      this.result = result;
+    }
+    double getActual() {
+      return actual;
+    }
+    double getResult() {
+      return result;
+    }
+  }
+
+  // One entry per recorded instance, in insertion order; never null, so toString() works before any data arrives.
+  private final List<Result> results = new ArrayList<>();
+
+  /**
+   * Records one (actual, predicted) pair; a NaN prediction is reported as unpredictable.
+   * @param actual
+   *          The actual answer
+   * @param result
+   *          The regression result
+   */
+  public void addInstance(double actual, double result) {
+    results.add(new Result(actual, result));
+  }
+
+  /**
+   * Appends each row as an instance (earlier instances are kept): column 0 is the actual value, column 1 the result.
+   * @param results
+   *          The results table
+   */
+  public void setInstances(double[][] results) {
+    for (double[] res : results) {
+      addInstance(res[0], res[1]);
+    }
+  }
+
+  /** Renders the summary; with no predictable instances only the instance counts are printed. */
+  @Override
+  public String toString() {
+    double sumActual = 0.0;
+    double sumActualSquared = 0.0;
+    double sumResult = 0.0;
+    double sumResultSquared = 0.0;
+    double sumActualResult = 0.0;
+    double sumAbsolute = 0.0;
+    double sumAbsoluteSquared = 0.0;
+    int predictable = 0;
+    int unpredictable = 0;
+
+    for (Result res : results) {
+      double actual = res.getActual();
+      double result = res.getResult();
+      if (Double.isNaN(result)) {
+        unpredictable++;
+      } else {
+        sumActual += actual;
+        sumActualSquared += actual * actual;
+        sumResult += result;
+        sumResultSquared += result * result;
+        sumActualResult += actual * result;
+        double absolute = Math.abs(actual - result);
+        sumAbsolute += absolute;
+        sumAbsoluteSquared += absolute * absolute;
+        predictable++;
+      }
+    }
+
+    StringBuilder returnString = new StringBuilder();
+
+    returnString.append("=======================================================\n");
+    returnString.append("Summary\n");
+    returnString.append("-------------------------------------------------------\n");
+
+    if (predictable > 0) {
+      double varActual = sumActualSquared - sumActual * sumActual / predictable;
+      double varResult = sumResultSquared - sumResult * sumResult / predictable;
+      double varCo = sumActualResult - sumActual * sumResult / predictable;
+
+      double correlation;
+      if (varActual * varResult <= 0) {
+        correlation = 0.0;
+      } else {
+        correlation = varCo / Math.sqrt(varActual * varResult);
+      }
+
+      // Explicit US symbols keep '.' as the decimal separator without mutating the JVM-wide default locale.
+      NumberFormat decimalFormatter = new DecimalFormat("0.####", DecimalFormatSymbols.getInstance(Locale.US));
+
+      returnString.append(StringUtils.rightPad("Correlation coefficient", 40)).append(": ").append(
+        StringUtils.leftPad(decimalFormatter.format(correlation), 10)).append('\n');
+      returnString.append(StringUtils.rightPad("Mean absolute error", 40)).append(": ").append(
+        StringUtils.leftPad(decimalFormatter.format(sumAbsolute / predictable), 10)).append('\n');
+      returnString.append(StringUtils.rightPad("Root mean squared error", 40)).append(": ").append(
+        StringUtils.leftPad(decimalFormatter.format(Math.sqrt(sumAbsoluteSquared / predictable)),
+          10)).append('\n');
+    }
+    returnString.append(StringUtils.rightPad("Predictable Instances", 40)).append(": ").append(
+      StringUtils.leftPad(Integer.toString(predictable), 10)).append('\n');
+    returnString.append(StringUtils.rightPad("Unpredictable Instances", 40)).append(": ").append(
+      StringUtils.leftPad(Integer.toString(unpredictable), 10)).append('\n');
+    returnString.append(StringUtils.rightPad("Total Regressed Instances", 40)).append(": ").append(
+      StringUtils.leftPad(Integer.toString(results.size()), 10)).append('\n');
+    returnString.append('\n');
+
+    return returnString.toString();
+  }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/ResultAnalyzer.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/ResultAnalyzer.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/ResultAnalyzer.java
new file mode 100644
index 0000000..1711f19
--- /dev/null
+++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/ResultAnalyzer.java
@@ -0,0 +1,147 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier;
+
+import java.text.DecimalFormat;
+import java.text.NumberFormat;
+import java.util.Collection;
+
+import org.apache.commons.lang3.StringUtils;
+import org.apache.mahout.cf.taste.impl.common.RunningAverageAndStdDev;
+import org.apache.mahout.math.stats.OnlineSummarizer;
+
+/**
+ * ResultAnalyzer captures the classification statistics and displays them in a tabular manner: it keeps a
+ * {@link ConfusionMatrix}, correct/incorrect tallies and, when supplied, a summary of log-likelihoods.
+ */
+public class ResultAnalyzer {
+
+  private static final String DOUBLE_RULE = "=======================================================\n";
+  private static final String SINGLE_RULE = "-------------------------------------------------------\n";
+
+  private final ConfusionMatrix confusionMatrix;
+  private final OnlineSummarizer summarizer;
+  private boolean hasLL;
+
+  /*
+   * Reference summary layout (toString() prints a similar, but not identical, report):
+   *
+   * === Summary ===
+   *
+   * Correctly Classified Instances 635 92.9722 % Incorrectly Classified Instances 48 7.0278 % Kappa statistic
+   * 0.923 Mean absolute error 0.0096 Root mean squared error 0.0817 Relative absolute error 9.9344 % Root
+   * relative squared error 37.2742 % Total Number of Instances 683
+   */
+  private int correctlyClassified;
+  private int incorrectlyClassified;
+
+  public ResultAnalyzer(Collection<String> labelSet, String defaultLabel) {
+    confusionMatrix = new ConfusionMatrix(labelSet, defaultLabel);
+    summarizer = new OnlineSummarizer();
+  }
+
+  public ConfusionMatrix getConfusionMatrix() {
+    return this.confusionMatrix;
+  }
+
+  /**
+   * Records one classification outcome in the tallies, the confusion matrix and, if present, the
+   * log-likelihood summary.
+   *
+   * @param correctLabel
+   *          The correct label
+   * @param classifiedResult
+   *          The classified result; a log-likelihood of {@link Double#MAX_VALUE} means "none"
+   * @return whether the instance was correct or not
+   */
+  public boolean addInstance(String correctLabel, ClassifierResult classifiedResult) {
+    boolean isCorrect = correctLabel.equals(classifiedResult.getLabel());
+    if (isCorrect) {
+      correctlyClassified++;
+    } else {
+      incorrectlyClassified++;
+    }
+    confusionMatrix.addInstance(correctLabel, classifiedResult);
+
+    double logLikelihood = classifiedResult.getLogLikelihood();
+    if (logLikelihood != Double.MAX_VALUE) {
+      summarizer.add(logLikelihood);
+      hasLL = true;
+    }
+    return isCorrect;
+  }
+
+  @Override
+  public String toString() {
+    int totalClassified = correctlyClassified + incorrectlyClassified;
+    double percentageCorrect = (double) 100 * correctlyClassified / totalClassified;
+    double percentageIncorrect = (double) 100 * incorrectlyClassified / totalClassified;
+    NumberFormat fmt = new DecimalFormat("0.####");
+
+    StringBuilder out = new StringBuilder();
+    out.append('\n');
+    appendBanner(out, "Summary\n");
+    appendSummaryRow(out, "Correctly Classified Instances", correctlyClassified, fmt.format(percentageCorrect));
+    appendSummaryRow(out, "Incorrectly Classified Instances", incorrectlyClassified,
+        fmt.format(percentageIncorrect));
+    out.append(StringUtils.rightPad("Total Classified Instances", 40)).append(": ")
+        .append(StringUtils.leftPad(Integer.toString(totalClassified), 10)).append('\n');
+    out.append('\n');
+
+    out.append(confusionMatrix);
+    appendBanner(out, "Statistics\n");
+
+    RunningAverageAndStdDev normStats = confusionMatrix.getNormalizedStats();
+    appendStat(out, "Kappa", fmt.format(confusionMatrix.getKappa()), "\n");
+    appendStat(out, "Accuracy", fmt.format(confusionMatrix.getAccuracy()), "%\n");
+    appendStat(out, "Reliability", fmt.format(normStats.getAverage() * 100.00000001), "%\n");
+    appendStat(out, "Reliability (standard deviation)", fmt.format(normStats.getStandardDeviation()), "\n");
+    appendStat(out, "Weighted precision", fmt.format(confusionMatrix.getWeightedPrecision()), "\n");
+    appendStat(out, "Weighted recall", fmt.format(confusionMatrix.getWeightedRecall()), "\n");
+    appendStat(out, "Weighted F1 score", fmt.format(confusionMatrix.getWeightedF1score()), "\n");
+
+    if (hasLL) {
+      String indent = StringUtils.rightPad("", 30);
+      out.append(StringUtils.rightPad("Log-likelihood", 30)).append("mean : ")
+          .append(StringUtils.leftPad(fmt.format(summarizer.getMean()), 10)).append('\n');
+      out.append(indent).append(StringUtils.rightPad("25%-ile : ", 10))
+          .append(StringUtils.leftPad(fmt.format(summarizer.getQuartile(1)), 10)).append('\n');
+      out.append(indent).append(StringUtils.rightPad("75%-ile : ", 10))
+          .append(StringUtils.leftPad(fmt.format(summarizer.getQuartile(3)), 10)).append('\n');
+    }
+
+    return out.toString();
+  }
+
+  /** Appends the '=' rule, {@code title} (which carries its own trailing newline) and the '-' rule. */
+  private static void appendBanner(StringBuilder out, String title) {
+    out.append(DOUBLE_RULE).append(title).append(SINGLE_RULE);
+  }
+
+  /** Appends {@code label} padded to 40, ": ", then count and percentage each left-padded to 10, then "%". */
+  private static void appendSummaryRow(StringBuilder out, String label, int count, String percentage) {
+    out.append(StringUtils.rightPad(label, 40)).append(": ")
+        .append(StringUtils.leftPad(Integer.toString(count), 10)).append('\t')
+        .append(StringUtils.leftPad(percentage, 10)).append("%\n");
+  }
+
+  /** Appends {@code label} padded to 40 columns, {@code value} left-padded to 10, then {@code suffix}. */
+  private static void appendStat(StringBuilder out, String label, String value, String suffix) {
+    out.append(StringUtils.rightPad(label, 40)).append(StringUtils.leftPad(value, 10)).append(suffix);
+  }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/Bagging.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/Bagging.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/Bagging.java
new file mode 100644
index 0000000..f79a429
--- /dev/null
+++
b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/Bagging.java @@ -0,0 +1,61 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.classifier.df; + +import org.apache.mahout.classifier.df.builder.TreeBuilder; +import org.apache.mahout.classifier.df.data.Data; +import org.apache.mahout.classifier.df.node.Node; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.Arrays; +import java.util.Random; + +/** + * Builds a tree using bagging + */ +@Deprecated +public class Bagging { + + private static final Logger log = LoggerFactory.getLogger(Bagging.class); + + private final TreeBuilder treeBuilder; + + private final Data data; + + private final boolean[] sampled; + + public Bagging(TreeBuilder treeBuilder, Data data) { + this.treeBuilder = treeBuilder; + this.data = data; + sampled = new boolean[data.size()]; + } + + /** + * Builds one tree + */ + public Node build(Random rng) { + log.debug("Bagging..."); + Arrays.fill(sampled, false); + Data bag = data.bagging(rng, sampled); + + log.debug("Building..."); + return treeBuilder.build(rng, bag); + } + +} 
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/DFUtils.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/DFUtils.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/DFUtils.java new file mode 100644 index 0000000..c94292c --- /dev/null +++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/DFUtils.java @@ -0,0 +1,174 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * <p/> + * http://www.apache.org/licenses/LICENSE-2.0 + * <p/> + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.classifier.df; + +import java.io.DataInput; +import java.io.DataOutput; +import java.io.DataOutputStream; +import java.io.IOException; +import java.nio.charset.Charset; +import java.util.ArrayList; +import java.util.List; + +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.FSDataOutputStream; +import org.apache.hadoop.fs.FileStatus; +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.io.Writable; +import org.apache.mahout.classifier.df.node.Node; +import org.apache.mahout.common.iterator.sequencefile.PathFilters; + +/** + * Utility class that contains various helper methods + */ +@Deprecated +public final class DFUtils { + + private DFUtils() { + } + + /** + * Writes an Node[] into a DataOutput + * @throws java.io.IOException + */ + public static void writeArray(DataOutput out, Node[] array) throws IOException { + out.writeInt(array.length); + for (Node w : array) { + w.write(out); + } + } + + /** + * Reads a Node[] from a DataInput + * @throws java.io.IOException + */ + public static Node[] readNodeArray(DataInput in) throws IOException { + int length = in.readInt(); + Node[] nodes = new Node[length]; + for (int index = 0; index < length; index++) { + nodes[index] = Node.read(in); + } + + return nodes; + } + + /** + * Writes a double[] into a DataOutput + * @throws java.io.IOException + */ + public static void writeArray(DataOutput out, double[] array) throws IOException { + out.writeInt(array.length); + for (double value : array) { + out.writeDouble(value); + } + } + + /** + * Reads a double[] from a DataInput + * @throws java.io.IOException + */ + public static double[] readDoubleArray(DataInput in) throws IOException { + int length = in.readInt(); + double[] array = new double[length]; + for (int index = 0; index < length; index++) { + array[index] = in.readDouble(); + } + + return array; + } + + /** + * Writes an int[] into a DataOutput + * @throws 
java.io.IOException + */ + public static void writeArray(DataOutput out, int[] array) throws IOException { + out.writeInt(array.length); + for (int value : array) { + out.writeInt(value); + } + } + + /** + * Reads an int[] from a DataInput + * @throws java.io.IOException + */ + public static int[] readIntArray(DataInput in) throws IOException { + int length = in.readInt(); + int[] array = new int[length]; + for (int index = 0; index < length; index++) { + array[index] = in.readInt(); + } + + return array; + } + + /** + * Return a list of all files in the output directory + * @throws IOException if no file is found + */ + public static Path[] listOutputFiles(FileSystem fs, Path outputPath) throws IOException { + List<Path> outputFiles = new ArrayList<>(); + for (FileStatus s : fs.listStatus(outputPath, PathFilters.logsCRCFilter())) { + if (!s.isDir() && !s.getPath().getName().startsWith("_")) { + outputFiles.add(s.getPath()); + } + } + if (outputFiles.isEmpty()) { + throw new IOException("No output found !"); + } + return outputFiles.toArray(new Path[outputFiles.size()]); + } + + /** + * Formats a time interval in milliseconds to a String in the form "hours:minutes:seconds:millis" + */ + public static String elapsedTime(long milli) { + long seconds = milli / 1000; + milli %= 1000; + + long minutes = seconds / 60; + seconds %= 60; + + long hours = minutes / 60; + minutes %= 60; + + return hours + "h " + minutes + "m " + seconds + "s " + milli; + } + + public static void storeWritable(Configuration conf, Path path, Writable writable) throws IOException { + FileSystem fs = path.getFileSystem(conf); + + try (FSDataOutputStream out = fs.create(path)) { + writable.write(out); + } + } + + /** + * Write a string to a path. 
+ * @param conf From which the file system will be picked + * @param path Where the string will be written + * @param string The string to write + * @throws IOException if things go poorly + */ + public static void storeString(Configuration conf, Path path, String string) throws IOException { + try (DataOutputStream out = path.getFileSystem(conf).create(path)) { + out.write(string.getBytes(Charset.defaultCharset())); + } + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/DecisionForest.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/DecisionForest.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/DecisionForest.java new file mode 100644 index 0000000..c11cf34 --- /dev/null +++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/DecisionForest.java @@ -0,0 +1,241 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.classifier.df; + +import com.google.common.base.Preconditions; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.FSDataInputStream; +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.io.Writable; +import org.apache.mahout.classifier.df.data.Data; +import org.apache.mahout.classifier.df.data.DataUtils; +import org.apache.mahout.classifier.df.data.Dataset; +import org.apache.mahout.classifier.df.data.Instance; +import org.apache.mahout.classifier.df.node.Node; + +import java.io.DataInput; +import java.io.DataOutput; +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Random; + +/** + * Represents a forest of decision trees. + */ +@Deprecated +public class DecisionForest implements Writable { + + private final List<Node> trees; + + private DecisionForest() { + trees = new ArrayList<>(); + } + + public DecisionForest(List<Node> trees) { + Preconditions.checkArgument(trees != null && !trees.isEmpty(), "trees argument must not be null or empty"); + + this.trees = trees; + } + + List<Node> getTrees() { + return trees; + } + + /** + * Classifies the data and calls callback for each classification + */ + public void classify(Data data, double[][] predictions) { + Preconditions.checkArgument(data.size() == predictions.length, "predictions.length must be equal to data.size()"); + + if (data.isEmpty()) { + return; // nothing to classify + } + + int treeId = 0; + for (Node tree : trees) { + for (int index = 0; index < data.size(); index++) { + if (predictions[index] == null) { + predictions[index] = new double[trees.size()]; + } + predictions[index][treeId] = tree.classify(data.get(index)); + } + treeId++; + } + } + + /** + * predicts the label for the instance + * + * @param rng + * Random number generator, used to break ties randomly + * @return NaN if the label cannot be predicted + */ + public double classify(Dataset 
dataset, Random rng, Instance instance) { + if (dataset.isNumerical(dataset.getLabelId())) { + double sum = 0; + int cnt = 0; + for (Node tree : trees) { + double prediction = tree.classify(instance); + if (!Double.isNaN(prediction)) { + sum += prediction; + cnt++; + } + } + + if (cnt > 0) { + return sum / cnt; + } else { + return Double.NaN; + } + } else { + int[] predictions = new int[dataset.nblabels()]; + for (Node tree : trees) { + double prediction = tree.classify(instance); + if (!Double.isNaN(prediction)) { + predictions[(int) prediction]++; + } + } + + if (DataUtils.sum(predictions) == 0) { + return Double.NaN; // no prediction available + } + + return DataUtils.maxindex(rng, predictions); + } + } + + /** + * @return Mean number of nodes per tree + */ + public long meanNbNodes() { + long sum = 0; + + for (Node tree : trees) { + sum += tree.nbNodes(); + } + + return sum / trees.size(); + } + + /** + * @return Total number of nodes in all the trees + */ + public long nbNodes() { + long sum = 0; + + for (Node tree : trees) { + sum += tree.nbNodes(); + } + + return sum; + } + + /** + * @return Mean maximum depth per tree + */ + public long meanMaxDepth() { + long sum = 0; + + for (Node tree : trees) { + sum += tree.maxDepth(); + } + + return sum / trees.size(); + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (!(obj instanceof DecisionForest)) { + return false; + } + + DecisionForest rf = (DecisionForest) obj; + + return trees.size() == rf.getTrees().size() && trees.containsAll(rf.getTrees()); + } + + @Override + public int hashCode() { + return trees.hashCode(); + } + + @Override + public void write(DataOutput dataOutput) throws IOException { + dataOutput.writeInt(trees.size()); + for (Node tree : trees) { + tree.write(dataOutput); + } + } + + /** + * Reads the trees from the input and adds them to the existing trees + */ + @Override + public void readFields(DataInput dataInput) throws IOException { + int 
size = dataInput.readInt(); + for (int i = 0; i < size; i++) { + trees.add(Node.read(dataInput)); + } + } + + /** + * Read the forest from inputStream + * @param dataInput - input forest + * @return {@link org.apache.mahout.classifier.df.DecisionForest} + * @throws IOException + */ + public static DecisionForest read(DataInput dataInput) throws IOException { + DecisionForest forest = new DecisionForest(); + forest.readFields(dataInput); + return forest; + } + + /** + * Load the forest from a single file or a directory of files + * @throws java.io.IOException + */ + public static DecisionForest load(Configuration conf, Path forestPath) throws IOException { + FileSystem fs = forestPath.getFileSystem(conf); + Path[] files; + if (fs.getFileStatus(forestPath).isDir()) { + files = DFUtils.listOutputFiles(fs, forestPath); + } else { + files = new Path[]{forestPath}; + } + + DecisionForest forest = null; + for (Path path : files) { + try (FSDataInputStream dataInput = new FSDataInputStream(fs.open(path))) { + if (forest == null) { + forest = read(dataInput); + } else { + forest.readFields(dataInput); + } + } + } + + return forest; + + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/ErrorEstimate.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/ErrorEstimate.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/ErrorEstimate.java new file mode 100644 index 0000000..13cd386 --- /dev/null +++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/ErrorEstimate.java @@ -0,0 +1,51 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.classifier.df; + +import com.google.common.base.Preconditions; + +/** + * Various methods to compute from the output of a random forest + */ +@Deprecated +public final class ErrorEstimate { + + private ErrorEstimate() { + } + + public static double errorRate(double[] labels, double[] predictions) { + Preconditions.checkArgument(labels.length == predictions.length, "labels.length != predictions.length"); + double nberrors = 0; // number of instance that got bad predictions + double datasize = 0; // number of classified instances + + for (int index = 0; index < labels.length; index++) { + if (predictions[index] == -1) { + continue; // instance not classified + } + + if (predictions[index] != labels[index]) { + nberrors++; + } + + datasize++; + } + + return nberrors / datasize; + } + +}
