http://git-wip-us.apache.org/repos/asf/mahout/blob/02f75f99/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/svd/ParallelArraysSGDFactorizer.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/svd/ParallelArraysSGDFactorizer.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/svd/ParallelArraysSGDFactorizer.java new file mode 100644 index 0000000..a99d54c --- /dev/null +++ b/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/svd/ParallelArraysSGDFactorizer.java @@ -0,0 +1,265 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.cf.taste.example.kddcup.track1.svd; + +import org.apache.mahout.cf.taste.common.Refreshable; +import org.apache.mahout.cf.taste.common.TasteException; +import org.apache.mahout.cf.taste.impl.common.FastByIDMap; +import org.apache.mahout.cf.taste.impl.common.FullRunningAverage; +import org.apache.mahout.cf.taste.impl.common.LongPrimitiveIterator; +import org.apache.mahout.cf.taste.impl.common.RunningAverage; +import org.apache.mahout.cf.taste.impl.recommender.svd.Factorization; +import org.apache.mahout.cf.taste.impl.recommender.svd.Factorizer; +import org.apache.mahout.cf.taste.model.DataModel; +import org.apache.mahout.cf.taste.model.Preference; +import org.apache.mahout.common.RandomUtils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.Collection; +import java.util.Random; + +/** + * {@link Factorizer} based on Simon Funk's famous article <a href="http://sifter.org/~simon/journal/20061211.html"> + * "Netflix Update: Try this at home"</a>. + * + * Attempts to be as memory efficient as possible, only iterating once through the + * {@link FactorizablePreferences} or {@link DataModel} while copying everything to primitive arrays. + * Learning works in place on these datastructures after that. 
+ */ +public class ParallelArraysSGDFactorizer implements Factorizer { + + public static final double DEFAULT_LEARNING_RATE = 0.005; + public static final double DEFAULT_PREVENT_OVERFITTING = 0.02; + public static final double DEFAULT_RANDOM_NOISE = 0.005; + + private final int numFeatures; + private final int numIterations; + private final float minPreference; + private final float maxPreference; + + private final Random random; + private final double learningRate; + private final double preventOverfitting; + + private final FastByIDMap<Integer> userIDMapping; + private final FastByIDMap<Integer> itemIDMapping; + + private final double[][] userFeatures; + private final double[][] itemFeatures; + + private final int[] userIndexes; + private final int[] itemIndexes; + private final float[] values; + + private final double defaultValue; + private final double interval; + private final double[] cachedEstimates; + + + private static final Logger log = LoggerFactory.getLogger(ParallelArraysSGDFactorizer.class); + + public ParallelArraysSGDFactorizer(DataModel dataModel, int numFeatures, int numIterations) { + this(new DataModelFactorizablePreferences(dataModel), numFeatures, numIterations, DEFAULT_LEARNING_RATE, + DEFAULT_PREVENT_OVERFITTING, DEFAULT_RANDOM_NOISE); + } + + public ParallelArraysSGDFactorizer(DataModel dataModel, int numFeatures, int numIterations, double learningRate, + double preventOverfitting, double randomNoise) { + this(new DataModelFactorizablePreferences(dataModel), numFeatures, numIterations, learningRate, preventOverfitting, + randomNoise); + } + + public ParallelArraysSGDFactorizer(FactorizablePreferences factorizablePrefs, int numFeatures, int numIterations) { + this(factorizablePrefs, numFeatures, numIterations, DEFAULT_LEARNING_RATE, DEFAULT_PREVENT_OVERFITTING, + DEFAULT_RANDOM_NOISE); + } + + public ParallelArraysSGDFactorizer(FactorizablePreferences factorizablePreferences, int numFeatures, + int numIterations, double learningRate, double 
preventOverfitting, double randomNoise) { + + this.numFeatures = numFeatures; + this.numIterations = numIterations; + minPreference = factorizablePreferences.getMinPreference(); + maxPreference = factorizablePreferences.getMaxPreference(); + + this.random = RandomUtils.getRandom(); + this.learningRate = learningRate; + this.preventOverfitting = preventOverfitting; + + int numUsers = factorizablePreferences.numUsers(); + int numItems = factorizablePreferences.numItems(); + int numPrefs = factorizablePreferences.numPreferences(); + + log.info("Mapping {} users...", numUsers); + userIDMapping = new FastByIDMap<>(numUsers); + int index = 0; + LongPrimitiveIterator userIterator = factorizablePreferences.getUserIDs(); + while (userIterator.hasNext()) { + userIDMapping.put(userIterator.nextLong(), index++); + } + + log.info("Mapping {} items", numItems); + itemIDMapping = new FastByIDMap<>(numItems); + index = 0; + LongPrimitiveIterator itemIterator = factorizablePreferences.getItemIDs(); + while (itemIterator.hasNext()) { + itemIDMapping.put(itemIterator.nextLong(), index++); + } + + this.userIndexes = new int[numPrefs]; + this.itemIndexes = new int[numPrefs]; + this.values = new float[numPrefs]; + this.cachedEstimates = new double[numPrefs]; + + index = 0; + log.info("Loading {} preferences into memory", numPrefs); + RunningAverage average = new FullRunningAverage(); + for (Preference preference : factorizablePreferences.getPreferences()) { + userIndexes[index] = userIDMapping.get(preference.getUserID()); + itemIndexes[index] = itemIDMapping.get(preference.getItemID()); + values[index] = preference.getValue(); + cachedEstimates[index] = 0; + + average.addDatum(preference.getValue()); + + index++; + if (index % 1000000 == 0) { + log.info("Processed {} preferences", index); + } + } + log.info("Processed {} preferences, done.", index); + + double averagePreference = average.getAverage(); + log.info("Average preference value is {}", averagePreference); + + double 
prefInterval = factorizablePreferences.getMaxPreference() - factorizablePreferences.getMinPreference(); + defaultValue = Math.sqrt((averagePreference - prefInterval * 0.1) / numFeatures); + interval = prefInterval * 0.1 / numFeatures; + + userFeatures = new double[numUsers][numFeatures]; + itemFeatures = new double[numItems][numFeatures]; + + log.info("Initializing feature vectors..."); + for (int feature = 0; feature < numFeatures; feature++) { + for (int userIndex = 0; userIndex < numUsers; userIndex++) { + userFeatures[userIndex][feature] = defaultValue + (random.nextDouble() - 0.5) * interval * randomNoise; + } + for (int itemIndex = 0; itemIndex < numItems; itemIndex++) { + itemFeatures[itemIndex][feature] = defaultValue + (random.nextDouble() - 0.5) * interval * randomNoise; + } + } + } + + @Override + public Factorization factorize() throws TasteException { + for (int feature = 0; feature < numFeatures; feature++) { + log.info("Shuffling preferences..."); + shufflePreferences(); + log.info("Starting training of feature {} ...", feature); + for (int currentIteration = 0; currentIteration < numIterations; currentIteration++) { + if (currentIteration == numIterations - 1) { + double rmse = trainingIterationWithRmse(feature); + log.info("Finished training feature {} with RMSE {}", feature, rmse); + } else { + trainingIteration(feature); + } + } + if (feature < numFeatures - 1) { + log.info("Updating cache..."); + for (int index = 0; index < userIndexes.length; index++) { + cachedEstimates[index] = estimate(userIndexes[index], itemIndexes[index], feature, cachedEstimates[index], + false); + } + } + } + log.info("Factorization done"); + return new Factorization(userIDMapping, itemIDMapping, userFeatures, itemFeatures); + } + + private void trainingIteration(int feature) { + for (int index = 0; index < userIndexes.length; index++) { + train(userIndexes[index], itemIndexes[index], feature, values[index], cachedEstimates[index]); + } + } + + private double 
trainingIterationWithRmse(int feature) { + double rmse = 0.0; + for (int index = 0; index < userIndexes.length; index++) { + double error = train(userIndexes[index], itemIndexes[index], feature, values[index], cachedEstimates[index]); + rmse += error * error; + } + return Math.sqrt(rmse / userIndexes.length); + } + + private double estimate(int userIndex, int itemIndex, int feature, double cachedEstimate, boolean trailing) { + double sum = cachedEstimate; + sum += userFeatures[userIndex][feature] * itemFeatures[itemIndex][feature]; + if (trailing) { + sum += (numFeatures - feature - 1) * (defaultValue + interval) * (defaultValue + interval); + if (sum > maxPreference) { + sum = maxPreference; + } else if (sum < minPreference) { + sum = minPreference; + } + } + return sum; + } + + public double train(int userIndex, int itemIndex, int feature, double original, double cachedEstimate) { + double error = original - estimate(userIndex, itemIndex, feature, cachedEstimate, true); + double[] userVector = userFeatures[userIndex]; + double[] itemVector = itemFeatures[itemIndex]; + + userVector[feature] += learningRate * (error * itemVector[feature] - preventOverfitting * userVector[feature]); + itemVector[feature] += learningRate * (error * userVector[feature] - preventOverfitting * itemVector[feature]); + + return error; + } + + protected void shufflePreferences() { + /* Durstenfeld shuffle */ + for (int currentPos = userIndexes.length - 1; currentPos > 0; currentPos--) { + int swapPos = random.nextInt(currentPos + 1); + swapPreferences(currentPos, swapPos); + } + } + + private void swapPreferences(int posA, int posB) { + int tmpUserIndex = userIndexes[posA]; + int tmpItemIndex = itemIndexes[posA]; + float tmpValue = values[posA]; + double tmpEstimate = cachedEstimates[posA]; + + userIndexes[posA] = userIndexes[posB]; + itemIndexes[posA] = itemIndexes[posB]; + values[posA] = values[posB]; + cachedEstimates[posA] = cachedEstimates[posB]; + + userIndexes[posB] = tmpUserIndex; 
+ itemIndexes[posB] = tmpItemIndex; + values[posB] = tmpValue; + cachedEstimates[posB] = tmpEstimate; + } + + @Override + public void refresh(Collection<Refreshable> alreadyRefreshed) { + // do nothing + } + +}
http://git-wip-us.apache.org/repos/asf/mahout/blob/02f75f99/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/svd/Track1SVDRunner.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/svd/Track1SVDRunner.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/svd/Track1SVDRunner.java new file mode 100644 index 0000000..5cce02d --- /dev/null +++ b/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/svd/Track1SVDRunner.java @@ -0,0 +1,141 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.cf.taste.example.kddcup.track1.svd; + +import java.io.BufferedOutputStream; +import java.io.File; +import java.io.FileOutputStream; +import java.io.OutputStream; + +import com.google.common.io.Closeables; +import org.apache.mahout.cf.taste.common.NoSuchItemException; +import org.apache.mahout.cf.taste.common.NoSuchUserException; +import org.apache.mahout.cf.taste.example.kddcup.DataFileIterable; +import org.apache.mahout.cf.taste.example.kddcup.KDDCupDataModel; +import org.apache.mahout.cf.taste.example.kddcup.track1.EstimateConverter; +import org.apache.mahout.cf.taste.impl.common.FullRunningAverage; +import org.apache.mahout.cf.taste.impl.common.RunningAverage; +import org.apache.mahout.cf.taste.impl.recommender.svd.Factorization; +import org.apache.mahout.cf.taste.impl.recommender.svd.Factorizer; +import org.apache.mahout.cf.taste.model.Preference; +import org.apache.mahout.cf.taste.model.PreferenceArray; +import org.apache.mahout.common.Pair; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * run an SVD factorization of the KDD track1 data. 
+ * + * needs at least 6-7GB of memory, tested with -Xms6700M -Xmx6700M + * + */ +public final class Track1SVDRunner { + + private static final Logger log = LoggerFactory.getLogger(Track1SVDRunner.class); + + private Track1SVDRunner() { + } + + public static void main(String[] args) throws Exception { + + if (args.length != 2) { + System.err.println("Necessary arguments: <kddDataFileDirectory> <resultFile>"); + return; + } + + File dataFileDirectory = new File(args[0]); + if (!dataFileDirectory.exists() || !dataFileDirectory.isDirectory()) { + throw new IllegalArgumentException("Bad data file directory: " + dataFileDirectory); + } + + File resultFile = new File(args[1]); + + /* the knobs to turn */ + int numFeatures = 20; + int numIterations = 5; + double learningRate = 0.0001; + double preventOverfitting = 0.002; + double randomNoise = 0.0001; + + + KDDCupFactorizablePreferences factorizablePreferences = + new KDDCupFactorizablePreferences(KDDCupDataModel.getTrainingFile(dataFileDirectory)); + + Factorizer sgdFactorizer = new ParallelArraysSGDFactorizer(factorizablePreferences, numFeatures, numIterations, + learningRate, preventOverfitting, randomNoise); + + Factorization factorization = sgdFactorizer.factorize(); + + log.info("Estimating validation preferences..."); + int prefsProcessed = 0; + RunningAverage average = new FullRunningAverage(); + for (Pair<PreferenceArray,long[]> validationPair + : new DataFileIterable(KDDCupDataModel.getValidationFile(dataFileDirectory))) { + for (Preference validationPref : validationPair.getFirst()) { + double estimate = estimatePreference(factorization, validationPref.getUserID(), validationPref.getItemID(), + factorizablePreferences.getMinPreference(), factorizablePreferences.getMaxPreference()); + double error = validationPref.getValue() - estimate; + average.addDatum(error * error); + prefsProcessed++; + if (prefsProcessed % 100000 == 0) { + log.info("Computed {} estimations", prefsProcessed); + } + } + } + 
log.info("Computed {} estimations, done.", prefsProcessed); + + double rmse = Math.sqrt(average.getAverage()); + log.info("RMSE {}", rmse); + + log.info("Estimating test preferences..."); + OutputStream out = null; + try { + out = new BufferedOutputStream(new FileOutputStream(resultFile)); + + for (Pair<PreferenceArray,long[]> testPair + : new DataFileIterable(KDDCupDataModel.getTestFile(dataFileDirectory))) { + for (Preference testPref : testPair.getFirst()) { + double estimate = estimatePreference(factorization, testPref.getUserID(), testPref.getItemID(), + factorizablePreferences.getMinPreference(), factorizablePreferences.getMaxPreference()); + byte result = EstimateConverter.convert(estimate, testPref.getUserID(), testPref.getItemID()); + out.write(result); + } + } + } finally { + Closeables.close(out, false); + } + log.info("wrote estimates to {}, done.", resultFile.getAbsolutePath()); + } + + static double estimatePreference(Factorization factorization, long userID, long itemID, float minPreference, + float maxPreference) throws NoSuchUserException, NoSuchItemException { + double[] userFeatures = factorization.getUserFeatures(userID); + double[] itemFeatures = factorization.getItemFeatures(itemID); + double estimate = 0; + for (int feature = 0; feature < userFeatures.length; feature++) { + estimate += userFeatures[feature] * itemFeatures[feature]; + } + if (estimate < minPreference) { + estimate = minPreference; + } else if (estimate > maxPreference) { + estimate = maxPreference; + } + return estimate; + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/02f75f99/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/HybridSimilarity.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/HybridSimilarity.java 
b/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/HybridSimilarity.java new file mode 100644 index 0000000..ce025a9 --- /dev/null +++ b/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/HybridSimilarity.java @@ -0,0 +1,62 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.cf.taste.example.kddcup.track2; + +import java.io.File; +import java.io.IOException; +import java.util.Collection; + +import org.apache.mahout.cf.taste.common.Refreshable; +import org.apache.mahout.cf.taste.common.TasteException; +import org.apache.mahout.cf.taste.impl.similarity.AbstractItemSimilarity; +import org.apache.mahout.cf.taste.impl.similarity.LogLikelihoodSimilarity; +import org.apache.mahout.cf.taste.model.DataModel; +import org.apache.mahout.cf.taste.similarity.ItemSimilarity; + +final class HybridSimilarity extends AbstractItemSimilarity { + + private final ItemSimilarity cfSimilarity; + private final ItemSimilarity contentSimilarity; + + HybridSimilarity(DataModel dataModel, File dataFileDirectory) throws IOException { + super(dataModel); + cfSimilarity = new LogLikelihoodSimilarity(dataModel); + contentSimilarity = new TrackItemSimilarity(dataFileDirectory); + } + + @Override + public double itemSimilarity(long itemID1, long itemID2) throws TasteException { + return contentSimilarity.itemSimilarity(itemID1, itemID2) * cfSimilarity.itemSimilarity(itemID1, itemID2); + } + + @Override + public double[] itemSimilarities(long itemID1, long[] itemID2s) throws TasteException { + double[] result = contentSimilarity.itemSimilarities(itemID1, itemID2s); + double[] multipliers = cfSimilarity.itemSimilarities(itemID1, itemID2s); + for (int i = 0; i < result.length; i++) { + result[i] *= multipliers[i]; + } + return result; + } + + @Override + public void refresh(Collection<Refreshable> alreadyRefreshed) { + cfSimilarity.refresh(alreadyRefreshed); + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/02f75f99/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/Track2Callable.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/Track2Callable.java 
b/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/Track2Callable.java new file mode 100644 index 0000000..50fd35e --- /dev/null +++ b/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/Track2Callable.java @@ -0,0 +1,106 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.cf.taste.example.kddcup.track2; + +import org.apache.mahout.cf.taste.common.NoSuchItemException; +import org.apache.mahout.cf.taste.common.TasteException; +import org.apache.mahout.cf.taste.model.PreferenceArray; +import org.apache.mahout.cf.taste.recommender.Recommender; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.HashSet; +import java.util.List; +import java.util.TreeMap; +import java.util.concurrent.Callable; +import java.util.concurrent.atomic.AtomicInteger; + +final class Track2Callable implements Callable<UserResult> { + + private static final Logger log = LoggerFactory.getLogger(Track2Callable.class); + private static final AtomicInteger COUNT = new AtomicInteger(); + + private final Recommender recommender; + private final PreferenceArray userTest; + + Track2Callable(Recommender recommender, PreferenceArray userTest) { + this.recommender = recommender; + this.userTest = userTest; + } + + @Override + public UserResult call() throws TasteException { + + int testSize = userTest.length(); + if (testSize != 6) { + throw new IllegalArgumentException("Expecting 6 items for user but got " + userTest); + } + long userID = userTest.get(0).getUserID(); + TreeMap<Double,Long> estimateToItemID = new TreeMap<>(Collections.reverseOrder()); + + for (int i = 0; i < testSize; i++) { + long itemID = userTest.getItemID(i); + double estimate; + try { + estimate = recommender.estimatePreference(userID, itemID); + } catch (NoSuchItemException nsie) { + // OK in the sample data provided before the contest, should never happen otherwise + log.warn("Unknown item {}; OK unless this is the real contest data", itemID); + continue; + } + + if (!Double.isNaN(estimate)) { + estimateToItemID.put(estimate, itemID); + } + } + + Collection<Long> itemIDs = estimateToItemID.values(); + List<Long> topThree = new ArrayList<>(itemIDs); + if 
(topThree.size() > 3) { + topThree = topThree.subList(0, 3); + } else if (topThree.size() < 3) { + log.warn("Unable to recommend three items for {}", userID); + // Some NaNs - just guess at the rest then + Collection<Long> newItemIDs = new HashSet<>(3); + newItemIDs.addAll(itemIDs); + int i = 0; + while (i < testSize && newItemIDs.size() < 3) { + newItemIDs.add(userTest.getItemID(i)); + i++; + } + topThree = new ArrayList<>(newItemIDs); + } + if (topThree.size() != 3) { + throw new IllegalStateException(); + } + + boolean[] result = new boolean[testSize]; + for (int i = 0; i < testSize; i++) { + result[i] = topThree.contains(userTest.getItemID(i)); + } + + if (COUNT.incrementAndGet() % 1000 == 0) { + log.info("Completed {} users", COUNT.get()); + } + + return new UserResult(userID, result); + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/02f75f99/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/Track2Recommender.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/Track2Recommender.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/Track2Recommender.java new file mode 100644 index 0000000..185a00d --- /dev/null +++ b/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/Track2Recommender.java @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.cf.taste.example.kddcup.track2; + +import java.io.File; +import java.io.IOException; +import java.util.Collection; +import java.util.List; + +import org.apache.mahout.cf.taste.common.Refreshable; +import org.apache.mahout.cf.taste.common.TasteException; +import org.apache.mahout.cf.taste.impl.recommender.GenericBooleanPrefItemBasedRecommender; +import org.apache.mahout.cf.taste.model.DataModel; +import org.apache.mahout.cf.taste.recommender.IDRescorer; +import org.apache.mahout.cf.taste.recommender.RecommendedItem; +import org.apache.mahout.cf.taste.recommender.Recommender; +import org.apache.mahout.cf.taste.similarity.ItemSimilarity; + +public final class Track2Recommender implements Recommender { + + private final Recommender recommender; + + public Track2Recommender(DataModel dataModel, File dataFileDirectory) throws TasteException { + // Change this to whatever you like! 
+ ItemSimilarity similarity; + try { + similarity = new HybridSimilarity(dataModel, dataFileDirectory); + } catch (IOException ioe) { + throw new TasteException(ioe); + } + recommender = new GenericBooleanPrefItemBasedRecommender(dataModel, similarity); + } + + @Override + public List<RecommendedItem> recommend(long userID, int howMany) throws TasteException { + return recommender.recommend(userID, howMany); + } + + @Override + public List<RecommendedItem> recommend(long userID, int howMany, boolean includeKnownItems) throws TasteException { + return recommend(userID, howMany, null, includeKnownItems); + } + + @Override + public List<RecommendedItem> recommend(long userID, int howMany, IDRescorer rescorer) throws TasteException { + return recommender.recommend(userID, howMany, rescorer, false); + } + + @Override + public List<RecommendedItem> recommend(long userID, int howMany, IDRescorer rescorer, boolean includeKnownItems) + throws TasteException { + return recommender.recommend(userID, howMany, rescorer, includeKnownItems); + } + + @Override + public float estimatePreference(long userID, long itemID) throws TasteException { + return recommender.estimatePreference(userID, itemID); + } + + @Override + public void setPreference(long userID, long itemID, float value) throws TasteException { + recommender.setPreference(userID, itemID, value); + } + + @Override + public void removePreference(long userID, long itemID) throws TasteException { + recommender.removePreference(userID, itemID); + } + + @Override + public DataModel getDataModel() { + return recommender.getDataModel(); + } + + @Override + public void refresh(Collection<Refreshable> alreadyRefreshed) { + recommender.refresh(alreadyRefreshed); + } + + @Override + public String toString() { + return "Track1Recommender[recommender:" + recommender + ']'; + } + +} 
http://git-wip-us.apache.org/repos/asf/mahout/blob/02f75f99/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/Track2RecommenderBuilder.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/Track2RecommenderBuilder.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/Track2RecommenderBuilder.java new file mode 100644 index 0000000..09ade5d --- /dev/null +++ b/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/Track2RecommenderBuilder.java @@ -0,0 +1,33 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.cf.taste.example.kddcup.track2; + +import org.apache.mahout.cf.taste.common.TasteException; +import org.apache.mahout.cf.taste.eval.RecommenderBuilder; +import org.apache.mahout.cf.taste.example.kddcup.KDDCupDataModel; +import org.apache.mahout.cf.taste.model.DataModel; +import org.apache.mahout.cf.taste.recommender.Recommender; + +final class Track2RecommenderBuilder implements RecommenderBuilder { + + @Override + public Recommender buildRecommender(DataModel dataModel) throws TasteException { + return new Track2Recommender(dataModel, ((KDDCupDataModel) dataModel).getDataFileDirectory()); + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/02f75f99/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/Track2Runner.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/Track2Runner.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/Track2Runner.java new file mode 100644 index 0000000..3cbb61c --- /dev/null +++ b/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/Track2Runner.java @@ -0,0 +1,100 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.cf.taste.example.kddcup.track2; + +import org.apache.mahout.cf.taste.example.kddcup.DataFileIterable; +import org.apache.mahout.cf.taste.example.kddcup.KDDCupDataModel; +import org.apache.mahout.cf.taste.model.PreferenceArray; +import org.apache.mahout.common.Pair; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.BufferedOutputStream; +import java.io.File; +import java.io.FileOutputStream; +import java.io.OutputStream; +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; + +/** + * <p>Runs "track 2" of the KDD Cup competition using whatever recommender is inside {@link Track2Recommender} + * and attempts to output the result in the correct contest format.</p> + * + * <p>Run as: {@code Track2Runner [track 2 data file directory] [output file]}</p> + */ +public final class Track2Runner { + + private static final Logger log = LoggerFactory.getLogger(Track2Runner.class); + + private Track2Runner() { + } + + public static void main(String[] args) throws Exception { + + File dataFileDirectory = new File(args[0]); + if (!dataFileDirectory.exists() || !dataFileDirectory.isDirectory()) { + throw new IllegalArgumentException("Bad data file directory: " + dataFileDirectory); + } + + long start = System.currentTimeMillis(); + + KDDCupDataModel model = new KDDCupDataModel(KDDCupDataModel.getTrainingFile(dataFileDirectory)); + Track2Recommender recommender 
= new Track2Recommender(model, dataFileDirectory); + + long end = System.currentTimeMillis(); + log.info("Loaded model in {}s", (end - start) / 1000); + start = end; + + Collection<Track2Callable> callables = new ArrayList<>(); + for (Pair<PreferenceArray,long[]> tests : new DataFileIterable(KDDCupDataModel.getTestFile(dataFileDirectory))) { + PreferenceArray userTest = tests.getFirst(); + callables.add(new Track2Callable(recommender, userTest)); + } + + int cores = Runtime.getRuntime().availableProcessors(); + log.info("Running on {} cores", cores); + ExecutorService executor = Executors.newFixedThreadPool(cores); + List<Future<UserResult>> futures = executor.invokeAll(callables); + executor.shutdown(); + + end = System.currentTimeMillis(); + log.info("Ran recommendations in {}s", (end - start) / 1000); + start = end; + + try (OutputStream out = new BufferedOutputStream(new FileOutputStream(new File(args[1])))){ + long lastUserID = Long.MIN_VALUE; + for (Future<UserResult> future : futures) { + UserResult result = future.get(); + long userID = result.getUserID(); + if (userID <= lastUserID) { + throw new IllegalStateException(); + } + lastUserID = userID; + out.write(result.getResultBytes()); + } + } + + end = System.currentTimeMillis(); + log.info("Wrote output in {}s", (end - start) / 1000); + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/02f75f99/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/TrackData.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/TrackData.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/TrackData.java new file mode 100644 index 0000000..abd15f8 --- /dev/null +++ b/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/TrackData.java @@ -0,0 +1,71 @@ +/** + * Licensed 
to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.cf.taste.example.kddcup.track2; + +import java.util.regex.Pattern; + +import org.apache.mahout.cf.taste.impl.common.FastIDSet; + +final class TrackData { + + private static final Pattern PIPE = Pattern.compile("\\|"); + private static final String NO_VALUE = "None"; + static final long NO_VALUE_ID = Long.MIN_VALUE; + private static final FastIDSet NO_GENRES = new FastIDSet(); + + private final long trackID; + private final long albumID; + private final long artistID; + private final FastIDSet genreIDs; + + TrackData(CharSequence line) { + String[] tokens = PIPE.split(line); + trackID = Long.parseLong(tokens[0]); + albumID = parse(tokens[1]); + artistID = parse(tokens[2]); + if (tokens.length > 3) { + genreIDs = new FastIDSet(tokens.length - 3); + for (int i = 3; i < tokens.length; i++) { + genreIDs.add(Long.parseLong(tokens[i])); + } + } else { + genreIDs = NO_GENRES; + } + } + + private static long parse(String value) { + return NO_VALUE.equals(value) ? 
NO_VALUE_ID : Long.parseLong(value); + } + + public long getTrackID() { + return trackID; + } + + public long getAlbumID() { + return albumID; + } + + public long getArtistID() { + return artistID; + } + + public FastIDSet getGenreIDs() { + return genreIDs; + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/02f75f99/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/TrackItemSimilarity.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/TrackItemSimilarity.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/TrackItemSimilarity.java new file mode 100644 index 0000000..3012a84 --- /dev/null +++ b/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/TrackItemSimilarity.java @@ -0,0 +1,106 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.cf.taste.example.kddcup.track2; + +import java.io.File; +import java.io.IOException; +import java.util.Collection; + +import org.apache.mahout.cf.taste.common.Refreshable; +import org.apache.mahout.cf.taste.example.kddcup.KDDCupDataModel; +import org.apache.mahout.cf.taste.impl.common.FastByIDMap; +import org.apache.mahout.cf.taste.impl.common.FastIDSet; +import org.apache.mahout.cf.taste.impl.common.LongPrimitiveIterator; +import org.apache.mahout.cf.taste.similarity.ItemSimilarity; +import org.apache.mahout.common.iterator.FileLineIterable; + +final class TrackItemSimilarity implements ItemSimilarity { + + private final FastByIDMap<TrackData> trackData; + + TrackItemSimilarity(File dataFileDirectory) throws IOException { + trackData = new FastByIDMap<>(); + for (String line : new FileLineIterable(KDDCupDataModel.getTrackFile(dataFileDirectory))) { + TrackData trackDatum = new TrackData(line); + trackData.put(trackDatum.getTrackID(), trackDatum); + } + } + + @Override + public double itemSimilarity(long itemID1, long itemID2) { + if (itemID1 == itemID2) { + return 1.0; + } + TrackData data1 = trackData.get(itemID1); + TrackData data2 = trackData.get(itemID2); + if (data1 == null || data2 == null) { + return 0.0; + } + + // Arbitrarily decide that same album means "very similar" + if (data1.getAlbumID() != TrackData.NO_VALUE_ID && data1.getAlbumID() == data2.getAlbumID()) { + return 0.9; + } + // ... 
and same artist means "fairly similar" + if (data1.getArtistID() != TrackData.NO_VALUE_ID && data1.getArtistID() == data2.getArtistID()) { + return 0.7; + } + + // Tanimoto coefficient similarity based on genre, but maximum value of 0.25 + FastIDSet genres1 = data1.getGenreIDs(); + FastIDSet genres2 = data2.getGenreIDs(); + if (genres1 == null || genres2 == null) { + return 0.0; + } + int intersectionSize = genres1.intersectionSize(genres2); + if (intersectionSize == 0) { + return 0.0; + } + int unionSize = genres1.size() + genres2.size() - intersectionSize; + return intersectionSize / (4.0 * unionSize); + } + + @Override + public double[] itemSimilarities(long itemID1, long[] itemID2s) { + int length = itemID2s.length; + double[] result = new double[length]; + for (int i = 0; i < length; i++) { + result[i] = itemSimilarity(itemID1, itemID2s[i]); + } + return result; + } + + @Override + public long[] allSimilarItemIDs(long itemID) { + FastIDSet allSimilarItemIDs = new FastIDSet(); + LongPrimitiveIterator allItemIDs = trackData.keySetIterator(); + while (allItemIDs.hasNext()) { + long possiblySimilarItemID = allItemIDs.nextLong(); + if (!Double.isNaN(itemSimilarity(itemID, possiblySimilarItemID))) { + allSimilarItemIDs.add(possiblySimilarItemID); + } + } + return allSimilarItemIDs.toArray(); + } + + @Override + public void refresh(Collection<Refreshable> alreadyRefreshed) { + // do nothing + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/02f75f99/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/UserResult.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/UserResult.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/UserResult.java new file mode 100644 index 0000000..e554d10 --- /dev/null +++ 
b/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/UserResult.java @@ -0,0 +1,54 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.cf.taste.example.kddcup.track2; + +final class UserResult { + + private final long userID; + private final byte[] resultBytes; + + UserResult(long userID, boolean[] result) { + + this.userID = userID; + + int trueCount = 0; + for (boolean b : result) { + if (b) { + trueCount++; + } + } + if (trueCount != 3) { + throw new IllegalStateException(); + } + + resultBytes = new byte[result.length]; + for (int i = 0; i < result.length; i++) { + resultBytes[i] = (byte) (result[i] ? 
'1' : '0'); + } + } + + public long getUserID() { + return userID; + } + + public byte[] getResultBytes() { + return resultBytes; + } + + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/02f75f99/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/hadoop/example/als/netflix/NetflixDatasetConverter.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/hadoop/example/als/netflix/NetflixDatasetConverter.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/hadoop/example/als/netflix/NetflixDatasetConverter.java new file mode 100644 index 0000000..22f122e --- /dev/null +++ b/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/hadoop/example/als/netflix/NetflixDatasetConverter.java @@ -0,0 +1,140 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.cf.taste.hadoop.example.als.netflix; + +import com.google.common.base.Preconditions; +import org.apache.commons.io.Charsets; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.Path; +import org.apache.mahout.cf.taste.impl.model.GenericPreference; +import org.apache.mahout.cf.taste.model.Preference; +import org.apache.mahout.common.iterator.FileLineIterable; +import org.apache.mahout.common.iterator.FileLineIterator; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.BufferedWriter; +import java.io.File; +import java.io.IOException; +import java.io.OutputStreamWriter; +import java.util.ArrayList; +import java.util.List; +import java.util.regex.Pattern; + +/** converts the raw files provided by netflix to an appropriate input format */ +public final class NetflixDatasetConverter { + + private static final Logger log = LoggerFactory.getLogger(NetflixDatasetConverter.class); + + private static final Pattern SEPARATOR = Pattern.compile(","); + private static final String MOVIE_DENOTER = ":"; + private static final String TAB = "\t"; + private static final String NEWLINE = "\n"; + + private NetflixDatasetConverter() { + } + + public static void main(String[] args) throws IOException { + + if (args.length != 4) { + System.err.println("Usage: NetflixDatasetConverter /path/to/training_set/ /path/to/qualifying.txt " + + "/path/to/judging.txt /path/to/destination"); + return; + } + + String trainingDataDir = args[0]; + String qualifyingTxt = args[1]; + String judgingTxt = args[2]; + Path outputPath = new Path(args[3]); + + Configuration conf = new Configuration(); + FileSystem fs = FileSystem.get(outputPath.toUri(), conf); + + Preconditions.checkArgument(trainingDataDir != null, "Training Data location needs to be specified"); + log.info("Creating training set at {}/trainingSet/ratings.tsv ...", outputPath); + try (BufferedWriter writer = + new 
BufferedWriter( + new OutputStreamWriter( + fs.create(new Path(outputPath, "trainingSet/ratings.tsv")), Charsets.UTF_8))){ + + int ratingsProcessed = 0; + for (File movieRatings : new File(trainingDataDir).listFiles()) { + try (FileLineIterator lines = new FileLineIterator(movieRatings)) { + boolean firstLineRead = false; + String movieID = null; + while (lines.hasNext()) { + String line = lines.next(); + if (firstLineRead) { + String[] tokens = SEPARATOR.split(line); + String userID = tokens[0]; + String rating = tokens[1]; + writer.write(userID + TAB + movieID + TAB + rating + NEWLINE); + ratingsProcessed++; + if (ratingsProcessed % 1000000 == 0) { + log.info("{} ratings processed...", ratingsProcessed); + } + } else { + movieID = line.replaceAll(MOVIE_DENOTER, ""); + firstLineRead = true; + } + } + } + + } + log.info("{} ratings processed. done.", ratingsProcessed); + } + + log.info("Reading probes..."); + List<Preference> probes = new ArrayList<>(2817131); + long currentMovieID = -1; + for (String line : new FileLineIterable(new File(qualifyingTxt))) { + if (line.contains(MOVIE_DENOTER)) { + currentMovieID = Long.parseLong(line.replaceAll(MOVIE_DENOTER, "")); + } else { + long userID = Long.parseLong(SEPARATOR.split(line)[0]); + probes.add(new GenericPreference(userID, currentMovieID, 0)); + } + } + log.info("{} probes read...", probes.size()); + + log.info("Reading ratings, creating probe set at {}/probeSet/ratings.tsv ...", outputPath); + try (BufferedWriter writer = + new BufferedWriter(new OutputStreamWriter( + fs.create(new Path(outputPath, "probeSet/ratings.tsv")), Charsets.UTF_8))){ + int ratingsProcessed = 0; + for (String line : new FileLineIterable(new File(judgingTxt))) { + if (line.contains(MOVIE_DENOTER)) { + currentMovieID = Long.parseLong(line.replaceAll(MOVIE_DENOTER, "")); + } else { + float rating = Float.parseFloat(SEPARATOR.split(line)[0]); + Preference pref = probes.get(ratingsProcessed); + Preconditions.checkState(pref.getItemID() == 
currentMovieID); + ratingsProcessed++; + writer.write(pref.getUserID() + TAB + pref.getItemID() + TAB + rating + NEWLINE); + if (ratingsProcessed % 1000000 == 0) { + log.info("{} ratings processed...", ratingsProcessed); + } + } + } + log.info("{} ratings processed. done.", ratingsProcessed); + } + } + + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/02f75f99/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/similarity/precompute/example/BatchItemSimilaritiesGroupLens.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/similarity/precompute/example/BatchItemSimilaritiesGroupLens.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/similarity/precompute/example/BatchItemSimilaritiesGroupLens.java new file mode 100644 index 0000000..8021d00 --- /dev/null +++ b/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/similarity/precompute/example/BatchItemSimilaritiesGroupLens.java @@ -0,0 +1,65 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.cf.taste.similarity.precompute.example; + +import org.apache.mahout.cf.taste.impl.recommender.GenericItemBasedRecommender; +import org.apache.mahout.cf.taste.impl.similarity.LogLikelihoodSimilarity; +import org.apache.mahout.cf.taste.impl.similarity.precompute.FileSimilarItemsWriter; +import org.apache.mahout.cf.taste.impl.similarity.precompute.MultithreadedBatchItemSimilarities; +import org.apache.mahout.cf.taste.model.DataModel; +import org.apache.mahout.cf.taste.recommender.ItemBasedRecommender; +import org.apache.mahout.cf.taste.similarity.precompute.BatchItemSimilarities; + +import java.io.File; + +/** + * Example that precomputes all item similarities of the Movielens1M dataset + * + * Usage: download movielens1M from http://www.grouplens.org/node/73 , unzip it and invoke this code with the path + * to the ratings.dat file as argument + * + */ +public final class BatchItemSimilaritiesGroupLens { + + private BatchItemSimilaritiesGroupLens() {} + + public static void main(String[] args) throws Exception { + + if (args.length != 1) { + System.err.println("Need path to ratings.dat of the movielens1M dataset as argument!"); + System.exit(-1); + } + + File resultFile = new File(System.getProperty("java.io.tmpdir"), "similarities.csv"); + if (resultFile.exists()) { + resultFile.delete(); + } + + DataModel dataModel = new GroupLensDataModel(new File(args[0])); + ItemBasedRecommender recommender = new GenericItemBasedRecommender(dataModel, + new LogLikelihoodSimilarity(dataModel)); + BatchItemSimilarities batch = new MultithreadedBatchItemSimilarities(recommender, 5); + + int numSimilarities = batch.computeItemSimilarities(Runtime.getRuntime().availableProcessors(), 1, + new FileSimilarItemsWriter(resultFile)); + + System.out.println("Computed " + numSimilarities + " similarities for " + dataModel.getNumItems() + " items " + + "and saved them to " + resultFile.getAbsolutePath()); + } + +} 
http://git-wip-us.apache.org/repos/asf/mahout/blob/02f75f99/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/similarity/precompute/example/GroupLensDataModel.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/similarity/precompute/example/GroupLensDataModel.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/similarity/precompute/example/GroupLensDataModel.java new file mode 100644 index 0000000..7ee9b17 --- /dev/null +++ b/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/similarity/precompute/example/GroupLensDataModel.java @@ -0,0 +1,96 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.cf.taste.similarity.precompute.example; + +import com.google.common.io.Files; +import com.google.common.io.InputSupplier; +import com.google.common.io.Resources; +import java.io.File; +import java.io.FileOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStreamWriter; +import java.io.Writer; +import java.net.URL; +import java.util.regex.Pattern; +import org.apache.commons.io.Charsets; +import org.apache.mahout.cf.taste.impl.model.file.FileDataModel; +import org.apache.mahout.common.iterator.FileLineIterable; + +public final class GroupLensDataModel extends FileDataModel { + + private static final String COLON_DELIMTER = "::"; + private static final Pattern COLON_DELIMITER_PATTERN = Pattern.compile(COLON_DELIMTER); + + public GroupLensDataModel() throws IOException { + this(readResourceToTempFile("/org/apache/mahout/cf/taste/example/grouplens/ratings.dat")); + } + + /** + * @param ratingsFile GroupLens ratings.dat file in its native format + * @throws IOException if an error occurs while reading or writing files + */ + public GroupLensDataModel(File ratingsFile) throws IOException { + super(convertGLFile(ratingsFile)); + } + + private static File convertGLFile(File originalFile) throws IOException { + // Now translate the file; remove commas, then convert "::" delimiter to comma + File resultFile = new File(new File(System.getProperty("java.io.tmpdir")), "ratings.txt"); + if (resultFile.exists()) { + resultFile.delete(); + } + try (Writer writer = new OutputStreamWriter(new FileOutputStream(resultFile), Charsets.UTF_8)){ + for (String line : new FileLineIterable(originalFile, false)) { + int lastDelimiterStart = line.lastIndexOf(COLON_DELIMTER); + if (lastDelimiterStart < 0) { + throw new IOException("Unexpected input format on line: " + line); + } + String subLine = line.substring(0, lastDelimiterStart); + String convertedLine = COLON_DELIMITER_PATTERN.matcher(subLine).replaceAll(","); + 
writer.write(convertedLine); + writer.write('\n'); + } + } catch (IOException ioe) { + resultFile.delete(); + throw ioe; + } + return resultFile; + } + + public static File readResourceToTempFile(String resourceName) throws IOException { + InputSupplier<? extends InputStream> inSupplier; + try { + URL resourceURL = Resources.getResource(GroupLensDataModel.class, resourceName); + inSupplier = Resources.newInputStreamSupplier(resourceURL); + } catch (IllegalArgumentException iae) { + File resourceFile = new File("src/main/java" + resourceName); + inSupplier = Files.newInputStreamSupplier(resourceFile); + } + File tempFile = File.createTempFile("taste", null); + tempFile.deleteOnExit(); + Files.copy(inSupplier, tempFile); + return tempFile; + } + + @Override + public String toString() { + return "GroupLensDataModel"; + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/02f75f99/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/NewsgroupHelper.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/NewsgroupHelper.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/NewsgroupHelper.java new file mode 100644 index 0000000..5cec51c --- /dev/null +++ b/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/NewsgroupHelper.java @@ -0,0 +1,128 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */

package org.apache.mahout.classifier;

import com.google.common.collect.ConcurrentHashMultiset;
import com.google.common.collect.Multiset;
import com.google.common.io.Closeables;
import com.google.common.io.Files;
import org.apache.commons.io.Charsets;
import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.analysis.TokenStream;
import org.apache.lucene.analysis.standard.StandardAnalyzer;
import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
import org.apache.mahout.common.RandomUtils;
import org.apache.mahout.math.RandomAccessSparseVector;
import org.apache.mahout.math.Vector;
import org.apache.mahout.vectorizer.encoders.ConstantValueEncoder;
import org.apache.mahout.vectorizer.encoders.FeatureVectorEncoder;
import org.apache.mahout.vectorizer.encoders.StaticWordValueEncoder;

import java.io.BufferedReader;
import java.io.File;
import java.io.IOException;
import java.io.Reader;
import java.io.StringReader;
import java.text.SimpleDateFormat;
import java.util.Collection;
import java.util.Date;
import java.util.Locale;
import java.util.Random;

/**
 * Encodes a newsgroup-style message file (header lines, an empty line, then the body) as a hashed
 * feature vector, optionally "leaking" a synthetic date, selected headers and the body into the features.
 */
public final class NewsgroupHelper {

  /**
   * Date renderings selected by {@code leakType % 3}. Index 0 has an empty pattern and therefore formats
   * to an empty string. SimpleDateFormat is not thread-safe, so all use goes through {@link #formatDate}.
   */
  private static final SimpleDateFormat[] DATE_FORMATS = {
    new SimpleDateFormat("", Locale.ENGLISH),
    new SimpleDateFormat("MMM-yyyy", Locale.ENGLISH),
    new SimpleDateFormat("dd-MMM-yyyy HH:mm:ss", Locale.ENGLISH)
  };

  /** Cardinality of the vectors produced by {@link #encodeFeatureVector}. */
  public static final int FEATURES = 10000;
  // 1997-01-15 00:01:00 GMT, in seconds since the epoch
  private static final long DATE_REFERENCE = 853286460;
  /** A 30-day "month", in seconds. */
  private static final long MONTH = 30 * 24 * 3600;
  /** One week, in seconds. */
  private static final long WEEK = 7 * 24 * 3600;

  private final Random rand = RandomUtils.getRandom();
  private final Analyzer analyzer = new StandardAnalyzer();
  private final FeatureVectorEncoder encoder = new StaticWordValueEncoder("body");
  private final FeatureVectorEncoder bias = new ConstantValueEncoder("Intercept");

  /** Returns the encoder used for word features. */
  public FeatureVectorEncoder getEncoder() {
    return encoder;
  }

  /** Returns the encoder used for the constant intercept feature. */
  public FeatureVectorEncoder getBias() {
    return bias;
  }

  /** Returns the random source used to jitter the synthetic dates. */
  public Random getRandom() {
    return rand;
  }

  /**
   * Encodes one message file as a hashed feature vector.
   *
   * <p>The tokens counted are:
   * <ul>
   *   <li>a synthetic posting date rendered with {@code DATE_FORMATS[leakType % 3]};</li>
   *   <li>the {@code From:}, {@code Subject:}, {@code Keywords:} and {@code Summary:} headers, including
   *   their space-indented continuation lines, when {@code leakType < 6};</li>
   *   <li>the message body that follows the first empty line, when {@code leakType < 3}.</li>
   * </ul>
   * Each distinct token {@code w} is added with weight {@code log1p(count(w))}, plus a constant
   * intercept feature.
   *
   * @param file          message file, read as UTF-8
   * @param actual        the synthetic date is shifted by {@code actual} 30-day months from the reference
   *                      date (presumably the true class index, so the date leaks the label — TODO confirm)
   * @param leakType      selects the date format, headers and body counted (see above); expected to be
   *                      non-negative, since a negative value makes {@code leakType % 3} an invalid index
   * @param overallCounts multiset that receives every token occurrence counted for this file
   * @return a {@link RandomAccessSparseVector} of cardinality {@link #FEATURES}
   * @throws IOException if the file cannot be read or tokenized
   */
  public Vector encodeFeatureVector(File file, int actual, int leakType, Multiset<String> overallCounts)
      throws IOException {
    // Synthetic timestamp in milliseconds: reference date + `actual` months + up to one week of jitter.
    long date = (long) (1000 * (DATE_REFERENCE + actual * MONTH + WEEK * rand.nextDouble()));
    Multiset<String> words = ConcurrentHashMultiset.create();

    try (BufferedReader reader = Files.newReader(file, Charsets.UTF_8)) {
      String line = reader.readLine();
      Reader dateString = new StringReader(formatDate(leakType, new Date(date)));
      countWords(analyzer, words, dateString, overallCounts);
      // Header section: everything up to the first empty line (or end of file).
      while (line != null && !line.isEmpty()) {
        boolean countHeader = leakType < 6 && isCountedHeader(line);
        // Consume the header line plus any continuation lines, which start with a space.
        do {
          if (countHeader) {
            countWords(analyzer, words, new StringReader(line), overallCounts);
          }
          line = reader.readLine();
        } while (line != null && line.startsWith(" "));
      }
      if (leakType < 3) {
        // The reader is now positioned after the header section, i.e. at the message body.
        countWords(analyzer, words, reader, overallCounts);
      }
    }

    Vector v = new RandomAccessSparseVector(FEATURES);
    bias.addToVector("", 1, v);
    for (String word : words.elementSet()) {
      encoder.addToVector(word, Math.log1p(words.count(word)), v);
    }

    return v;
  }

  /**
   * Tokenizes {@code in} with {@code analyzer} and records every token in both {@code words} and
   * {@code overallCounts}.
   *
   * <p>Only the tokens produced by this call are added to {@code overallCounts}. A caller that reuses one
   * {@code words} collection across several calls (as {@link #encodeFeatureVector} does) therefore does
   * not re-count tokens from earlier calls.
   *
   * @param analyzer      analyzer used to tokenize the text
   * @param words         collection receiving each token of this call
   * @param in            text to tokenize
   * @param overallCounts multiset receiving each token of this call
   * @throws IOException if tokenizing or closing the token stream fails
   */
  public static void countWords(Analyzer analyzer,
                                Collection<String> words,
                                Reader in,
                                Multiset<String> overallCounts) throws IOException {
    // try-with-resources closes the stream even when tokenization fails part-way.
    try (TokenStream ts = analyzer.tokenStream("text", in)) {
      CharTermAttribute termAttribute = ts.addAttribute(CharTermAttribute.class);
      ts.reset();
      while (ts.incrementToken()) {
        String token = termAttribute.toString();
        words.add(token);
        overallCounts.add(token);
      }
      ts.end();
    }
  }

  /** Returns whether {@code line} starts one of the headers whose text may be leaked into the features. */
  private static boolean isCountedHeader(String line) {
    return line.startsWith("From:") || line.startsWith("Subject:")
        || line.startsWith("Keywords:") || line.startsWith("Summary:");
  }

  /** Formats {@code date} with {@code DATE_FORMATS[leakType % 3]}, locking the shared, non-thread-safe format. */
  private static String formatDate(int leakType, Date date) {
    SimpleDateFormat format = DATE_FORMATS[leakType % 3];
    synchronized (format) {
      return format.format(date);
    }
  }
}
http://git-wip-us.apache.org/repos/asf/mahout/blob/02f75f99/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/email/PrepEmailMapper.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/email/PrepEmailMapper.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/email/PrepEmailMapper.java new file mode 100644 index 0000000..16e9d80 --- /dev/null +++ b/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/email/PrepEmailMapper.java @@ -0,0 +1,65 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */

package org.apache.mahout.classifier.email;

import org.apache.hadoop.io.Text;
import org.apache.hadoop.io.WritableComparable;
import org.apache.hadoop.mapreduce.Mapper;
import org.apache.mahout.math.VectorWritable;

import java.io.IOException;
import java.util.Locale;
import java.util.regex.Pattern;

/**
 * Rewrites the path-style keys produced by the {@link org.apache.mahout.utils.email.MailProcessor} into
 * plain labels the classifiers can consume. The vector values are passed through unchanged.
 */
public class PrepEmailMapper extends Mapper<WritableComparable<?>, VectorWritable, Text, VectorWritable> {

  private static final Pattern DASH_DOT = Pattern.compile("-|\\.");
  private static final Pattern SLASH = Pattern.compile("\\/");

  // When true, the label is "<project>_<list>"; otherwise only "<project>".
  private boolean useListName = false;

  @Override
  protected void setup(Context context) throws IOException, InterruptedException {
    String flag = context.getConfiguration().get(PrepEmailVectorsDriver.USE_LIST_NAME);
    useListName = Boolean.parseBoolean(flag);
  }

  @Override
  protected void map(WritableComparable<?> key, VectorWritable value, Context context)
    throws IOException, InterruptedException {
    // Example key: /cocoon.apache.org/dev/200307.gz/001401c3414f$8394e160$1e01a8c0@WRPO
    // Splitting on '/' yields ["", project, list, ...].
    String[] parts = SLASH.split(key.toString());
    if (parts.length < 3) {
      // Not enough path components to build a label: emit nothing for this record.
      return;
    }
    String label = escape(parts[1]);
    if (useListName) {
      label = label + '_' + escape(parts[2]);
    }
    context.write(new Text(label), value);
  }

  /** Replaces every '-' and '.' with '_' and lower-cases the result (English locale). */
  private static String escape(CharSequence value) {
    String underscored = DASH_DOT.matcher(value).replaceAll("_");
    return underscored.toLowerCase(Locale.ENGLISH);
  }
}
http://git-wip-us.apache.org/repos/asf/mahout/blob/02f75f99/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/email/PrepEmailReducer.java ---------------------------------------------------------------------- diff --git
a/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/email/PrepEmailReducer.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/email/PrepEmailReducer.java new file mode 100644 index 0000000..da6e613 --- /dev/null +++ b/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/email/PrepEmailReducer.java @@ -0,0 +1,47 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */

package org.apache.mahout.classifier.email;

import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Reducer;
import org.apache.mahout.math.VectorWritable;

import java.io.IOException;
import java.util.Iterator;

/**
 * Emits at most {@code maxItemsPerLabel} of the vectors grouped under each label key, in the order the
 * framework supplies them.
 */
public class PrepEmailReducer extends Reducer<Text, VectorWritable, Text, VectorWritable> {

  /** Per-label cap; this default applies when {@link PrepEmailVectorsDriver#ITEMS_PER_CLASS} is not configured. */
  private long maxItemsPerLabel = 10000;

  @Override
  protected void setup(Context context) throws IOException, InterruptedException {
    // getLong falls back to the default above when the key is absent, instead of failing with
    // NumberFormatException on Long.parseLong(null).
    maxItemsPerLabel = context.getConfiguration().getLong(PrepEmailVectorsDriver.ITEMS_PER_CLASS, maxItemsPerLabel);
  }

  @Override
  protected void reduce(Text key, Iterable<VectorWritable> values, Context context)
    throws IOException, InterruptedException {
    //TODO: support randomization? Likely not needed due to the SplitInput utility which does random selection
    long i = 0;
    Iterator<VectorWritable> iterator = values.iterator();
    // Check the cap before hasNext() so no further values are pulled once the limit is reached.
    while (i < maxItemsPerLabel && iterator.hasNext()) {
      context.write(key, iterator.next());
      i++;
    }
  }
}
http://git-wip-us.apache.org/repos/asf/mahout/blob/02f75f99/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/email/PrepEmailVectorsDriver.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/email/PrepEmailVectorsDriver.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/email/PrepEmailVectorsDriver.java new file mode 100644 index 0000000..8fba739 --- /dev/null +++ b/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/email/PrepEmailVectorsDriver.java @@ -0,0 +1,76 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */

package org.apache.mahout.classifier.email;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat;
import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
import org.apache.hadoop.util.ToolRunner;
import org.apache.mahout.common.AbstractJob;
import org.apache.mahout.common.HadoopUtil;
import org.apache.mahout.common.commandline.DefaultOptionCreator;
import org.apache.mahout.math.VectorWritable;

import java.util.List;
import java.util.Map;

/**
 * Relabels the vectors produced by {@link org.apache.mahout.text.SequenceFilesFromMailArchives} and
 * {@link org.apache.mahout.vectorizer.SparseVectorsFromSequenceFiles} so that the classifiers can consume
 * them. This runs as a separate job because doing it while the sparse vectors are created would make the
 * Reducer collapse all the vectors.
 */
public class PrepEmailVectorsDriver extends AbstractJob {

  /** Configuration key for the per-label item cap, read by {@link PrepEmailReducer}. */
  public static final String ITEMS_PER_CLASS = "itemsPerClass";
  /** Configuration key for the "include list name in label" flag, read by {@link PrepEmailMapper}. */
  public static final String USE_LIST_NAME = "USE_LIST_NAME";

  public static void main(String[] args) throws Exception {
    ToolRunner.run(new Configuration(), new PrepEmailVectorsDriver(), args);
  }

  /**
   * Parses the command line and runs the relabeling job.
   *
   * @return 0 when the job succeeds; -1 when argument parsing fails or the job does not succeed
   */
  @Override
  public int run(String[] args) throws Exception {
    addInputOption();
    addOutputOption();
    addOption(DefaultOptionCreator.overwriteOption().create());
    addOption("maxItemsPerLabel", "mipl",
        "The maximum number of items per label. Can be useful for making the training sets the same size",
        String.valueOf(100000));
    addOption(buildOption("useListName", "ul",
        "Use the name of the list as part of the label. If not set, then just use the project name",
        false, false, "false"));

    Map<String, List<String>> parsed = parseArguments(args);
    if (parsed == null) {
      return -1;
    }

    Path inputPath = getInputPath();
    Path outputPath = getOutputPath();
    if (hasOption(DefaultOptionCreator.OVERWRITE_OPTION)) {
      HadoopUtil.delete(getConf(), outputPath);
    }

    Job job = prepareJob(inputPath, outputPath,
        SequenceFileInputFormat.class,
        PrepEmailMapper.class, Text.class, VectorWritable.class,
        PrepEmailReducer.class, Text.class, VectorWritable.class,
        SequenceFileOutputFormat.class);
    Configuration jobConf = job.getConfiguration();
    jobConf.set(ITEMS_PER_CLASS, getOption("maxItemsPerLabel"));
    jobConf.set(USE_LIST_NAME, String.valueOf(hasOption("useListName")));

    return job.waitForCompletion(true) ? 0 : -1;
  }
}
http://git-wip-us.apache.org/repos/asf/mahout/blob/02f75f99/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/PosTagger.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/PosTagger.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/PosTagger.java new file mode 100644 index 0000000..9c0ef56 --- /dev/null +++ b/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/PosTagger.java @@ -0,0 +1,277 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */

package org.apache.mahout.classifier.sequencelearning.hmm;

import com.google.common.io.Resources;
import org.apache.commons.io.Charsets;
import org.apache.mahout.math.Matrix;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.net.URL;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.regex.Pattern;

/**
 * This class implements a sample program that uses a pre-tagged training data
 * set to train an HMM model as a POS tagger. The training data is automatically
 * downloaded from the following URL:
 * http://www.jaist.ac.jp/~hieuxuan/flexcrfs/CoNLL2000-NP/train.txt It then
 * trains an HMM Model using supervised learning and tests the model on the
 * following test data set:
 * http://www.jaist.ac.jp/~hieuxuan/flexcrfs/CoNLL2000-NP/test.txt Further
 * details regarding the data files can be found at
 * http://flexcrfs.sourceforge.net/#Case_Study
 */
public final class PosTagger {

  private static final Logger log = LoggerFactory.getLogger(PosTagger.class);

  private static final Pattern SPACE = Pattern.compile(" ");
  private static final Pattern SPACES = Pattern.compile("[ ]+");

  /**
   * No public constructors for utility classes.
   */
  private PosTagger() {
    // nothing to do here really.
  }

  /**
   * Model trained in the example.
   */
  private static HmmModel taggingModel;

  /**
   * Map for storing the IDs for the POS tags (hidden states)
   */
  private static Map<String, Integer> tagIDs;

  /**
   * Counter for the next assigned POS tag ID. Tag IDs are handed out starting at 0, so the first tag seen
   * in the training data gets ID 0. NOTE(review): test-data tags never seen during training are also
   * mapped to 0 by readFromURL, i.e. they are treated as that first tag.
   */
  private static int nextTagId;

  /**
   * Map for storing the IDs for observed words (observed states)
   */
  private static Map<String, Integer> wordIDs;

  /**
   * Counter for the next assigned word ID. The value of 0 is reserved for
   * "unknown word".
   */
  private static int nextWordId = 1; // 0 is reserved for "unknown word"

  /**
   * POS tag ID sequences of the sentences read by the last call to readFromURL.
   */
  private static List<int[]> hiddenSequences;

  /**
   * Word ID sequences of the sentences read by the last call to readFromURL.
   */
  private static List<int[]> observedSequences;

  /**
   * Number of non-empty lines read by the last call to readFromURL.
   */
  private static int readLines;

  /**
   * Given an URL, this function fetches the data file, parses it, assigns POS
   * Tag/word IDs and fills the hiddenSequences/observedSequences lists with
   * data from those files. The data is expected to be in the following format
   * (one word per line): word pos-tag np-tag. Sentences are separated by empty
   * lines; empty sentences are skipped.
   *
   * @param url Where the data file is stored
   * @param assignIDs Should IDs for unknown words/tags be assigned? (Needed for
   *          training data, not needed for test data)
   * @throws IOException in case data file cannot be read.
   */
  private static void readFromURL(String url, boolean assignIDs) throws IOException {
    // initialize the data structures; ArrayList because testModel accesses them by index
    hiddenSequences = new ArrayList<>();
    observedSequences = new ArrayList<>();
    readLines = 0;

    // word/tag IDs of the sentence currently being read
    List<Integer> observedSequence = new ArrayList<>();
    List<Integer> hiddenSequence = new ArrayList<>();

    for (String line : Resources.readLines(new URL(url), Charsets.UTF_8)) {
      if (line.isEmpty()) {
        // an empty line terminates the current sentence
        registerSequence(observedSequence, hiddenSequence);
        observedSequence.clear();
        hiddenSequence.clear();
        continue;
      }
      readLines++;
      // we expect the format [word] [POS tag] [NP tag]
      String[] tags = SPACE.split(line);
      // when analyzing the training set, assign IDs
      if (assignIDs) {
        if (!wordIDs.containsKey(tags[0])) {
          wordIDs.put(tags[0], nextWordId++);
        }
        if (!tagIDs.containsKey(tags[1])) {
          tagIDs.put(tags[1], nextTagId++);
        }
      }
      // words/tags without an ID (only possible when assignIDs is false) are mapped to 0
      Integer wordID = wordIDs.get(tags[0]);
      Integer tagID = tagIDs.get(tags[1]);
      observedSequence.add(wordID == null ? 0 : wordID);
      hiddenSequence.add(tagID == null ? 0 : tagID);
    }

    // register the last sentence if the file does not end with an empty line
    registerSequence(observedSequence, hiddenSequence);
  }

  /**
   * Copies one sentence's word and tag ID lists into arrays and appends them to observedSequences and
   * hiddenSequences. An empty sentence (e.g. from consecutive blank lines) is skipped, so it is neither
   * counted as a sentence nor passed on to training or decoding.
   */
  private static void registerSequence(List<Integer> observedSequence, List<Integer> hiddenSequence) {
    if (observedSequence.isEmpty()) {
      return;
    }
    int[] observedSequenceArray = new int[observedSequence.size()];
    int[] hiddenSequenceArray = new int[hiddenSequence.size()];
    for (int i = 0; i < observedSequence.size(); ++i) {
      observedSequenceArray[i] = observedSequence.get(i);
      hiddenSequenceArray[i] = hiddenSequence.get(i);
    }
    hiddenSequences.add(hiddenSequenceArray);
    observedSequences.add(observedSequenceArray);
  }

  /**
   * Reads the training file, assigns word/tag IDs and trains the supervised HMM stored in taggingModel.
   */
  private static void trainModel(String trainingURL) throws IOException {
    tagIDs = new HashMap<>(44); // we expect 44 distinct tags
    wordIDs = new HashMap<>(19122); // we expect 19122 distinct words
    log.info("Reading and parsing training data file from URL: {}", trainingURL);
    long start = System.currentTimeMillis();
    readFromURL(trainingURL, true);
    long end = System.currentTimeMillis();
    double duration = (end - start) / 1000.0;
    log.info("Parsing done in {} seconds!", duration);
    // The map sizes are the exact distinct counts (tag IDs start at 0, word IDs at 1).
    log.info("Read {} lines containing {} sentences with a total of {} distinct words and {} distinct POS tags.",
        readLines, hiddenSequences.size(), wordIDs.size(), tagIDs.size());
    start = System.currentTimeMillis();
    taggingModel = HmmTrainer.trainSupervisedSequence(nextTagId, nextWordId,
        hiddenSequences, observedSequences, 0.05);
    // we have to adjust the model a bit,
    // since we assume a higher probability that a given unknown word is NNP
    // than anything else
    Matrix emissions = taggingModel.getEmissionMatrix();
    for (int i = 0; i < taggingModel.getNrOfHiddenStates(); ++i) {
      emissions.setQuick(i, 0, 0.1 / taggingModel.getNrOfHiddenStates());
    }
    int nnptag = tagIDs.get("NNP");
    emissions.setQuick(nnptag, 0, 1 / (double) taggingModel.getNrOfHiddenStates());
    // re-normalize the emission probabilities
    HmmUtils.normalizeModel(taggingModel);
    // now register the names
    taggingModel.registerHiddenStateNames(tagIDs);
    taggingModel.registerOutputStateNames(wordIDs);
    end = System.currentTimeMillis();
    duration = (end - start) / 1000.0;
    log.info("Trained HMM models in {} seconds!", duration);
  }

  /**
   * Reads the test file, decodes every sentence with taggingModel and logs the per-token error rate.
   */
  private static void testModel(String testingURL) throws IOException {
    log.info("Reading and parsing test data file from URL: {}", testingURL);
    long start = System.currentTimeMillis();
    readFromURL(testingURL, false);
    long end = System.currentTimeMillis();
    double duration = (end - start) / 1000.0;
    log.info("Parsing done in {} seconds!", duration);
    log.info("Read {} lines containing {} sentences.", readLines, hiddenSequences.size());

    start = System.currentTimeMillis();
    int errorCount = 0;
    int totalCount = 0;
    for (int i = 0; i < observedSequences.size(); ++i) {
      // fetch the viterbi path as the POS tag for this observed sequence
      int[] posEstimate = HmmEvaluator.decode(taggingModel, observedSequences.get(i), false);
      // compare with the expected
      int[] posExpected = hiddenSequences.get(i);
      for (int j = 0; j < posExpected.length; ++j) {
        totalCount++;
        if (posEstimate[j] != posExpected[j]) {
          errorCount++;
        }
      }
    }
    end = System.currentTimeMillis();
    duration = (end - start) / 1000.0;
    log.info("POS tagged test file in {} seconds!", duration);
    double errorRate = (double) errorCount / totalCount;
    log.info("Tagged the test file with an error rate of: {}", errorRate);
  }

  /**
   * Tokenizes a raw sentence (isolating punctuation) and returns the decoded POS tag names.
   */
  private static List<String> tagSentence(String sentence) {
    // first, we need to isolate all punctuation characters, so that they
    // can be recognized
    sentence = sentence.replaceAll("[,.!?:;\"]", " $0 ");
    sentence = sentence.replaceAll("''", " '' ");
    // now we tokenize the sentence
    String[] tokens = SPACES.split(sentence);
    // now generate the observed sequence
    int[] observedSequence = HmmUtils.encodeStateSequence(taggingModel, Arrays.asList(tokens), true, 0);
    // POS tag this observedSequence
    int[] hiddenSequence = HmmEvaluator.decode(taggingModel, observedSequence, false);
    // and now decode the tag names
    return HmmUtils.decodeStateSequence(taggingModel, hiddenSequence, false, null);
  }

  public static void main(String[] args) throws IOException {
    // generate the model from URL
    trainModel("http://www.jaist.ac.jp/~hieuxuan/flexcrfs/CoNLL2000-NP/train.txt");
    testModel("http://www.jaist.ac.jp/~hieuxuan/flexcrfs/CoNLL2000-NP/test.txt");
    // tag an exemplary sentence
    String test = "McDonalds is a huge company with many employees .";
    String[] testWords = SPACE.split(test);
    List<String> posTags = tagSentence(test);
    for (int i = 0; i < posTags.size(); ++i) {
      log.info("{}[{}]", testWords[i], posTags.get(i));
    }
  }

}
