http://git-wip-us.apache.org/repos/asf/mahout/blob/99a5358f/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/svd/ParallelArraysSGDFactorizer.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/svd/ParallelArraysSGDFactorizer.java b/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/svd/ParallelArraysSGDFactorizer.java deleted file mode 100644 index a99d54c..0000000 --- a/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/svd/ParallelArraysSGDFactorizer.java +++ /dev/null @@ -1,265 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.mahout.cf.taste.example.kddcup.track1.svd; - -import org.apache.mahout.cf.taste.common.Refreshable; -import org.apache.mahout.cf.taste.common.TasteException; -import org.apache.mahout.cf.taste.impl.common.FastByIDMap; -import org.apache.mahout.cf.taste.impl.common.FullRunningAverage; -import org.apache.mahout.cf.taste.impl.common.LongPrimitiveIterator; -import org.apache.mahout.cf.taste.impl.common.RunningAverage; -import org.apache.mahout.cf.taste.impl.recommender.svd.Factorization; -import org.apache.mahout.cf.taste.impl.recommender.svd.Factorizer; -import org.apache.mahout.cf.taste.model.DataModel; -import org.apache.mahout.cf.taste.model.Preference; -import org.apache.mahout.common.RandomUtils; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import java.util.Collection; -import java.util.Random; - -/** - * {@link Factorizer} based on Simon Funk's famous article <a href="http://sifter.org/~simon/journal/20061211.html"> - * "Netflix Update: Try this at home"</a>. - * - * Attempts to be as memory efficient as possible, only iterating once through the - * {@link FactorizablePreferences} or {@link DataModel} while copying everything to primitive arrays. - * Learning works in place on these datastructures after that. 
- */ -public class ParallelArraysSGDFactorizer implements Factorizer { - - public static final double DEFAULT_LEARNING_RATE = 0.005; - public static final double DEFAULT_PREVENT_OVERFITTING = 0.02; - public static final double DEFAULT_RANDOM_NOISE = 0.005; - - private final int numFeatures; - private final int numIterations; - private final float minPreference; - private final float maxPreference; - - private final Random random; - private final double learningRate; - private final double preventOverfitting; - - private final FastByIDMap<Integer> userIDMapping; - private final FastByIDMap<Integer> itemIDMapping; - - private final double[][] userFeatures; - private final double[][] itemFeatures; - - private final int[] userIndexes; - private final int[] itemIndexes; - private final float[] values; - - private final double defaultValue; - private final double interval; - private final double[] cachedEstimates; - - - private static final Logger log = LoggerFactory.getLogger(ParallelArraysSGDFactorizer.class); - - public ParallelArraysSGDFactorizer(DataModel dataModel, int numFeatures, int numIterations) { - this(new DataModelFactorizablePreferences(dataModel), numFeatures, numIterations, DEFAULT_LEARNING_RATE, - DEFAULT_PREVENT_OVERFITTING, DEFAULT_RANDOM_NOISE); - } - - public ParallelArraysSGDFactorizer(DataModel dataModel, int numFeatures, int numIterations, double learningRate, - double preventOverfitting, double randomNoise) { - this(new DataModelFactorizablePreferences(dataModel), numFeatures, numIterations, learningRate, preventOverfitting, - randomNoise); - } - - public ParallelArraysSGDFactorizer(FactorizablePreferences factorizablePrefs, int numFeatures, int numIterations) { - this(factorizablePrefs, numFeatures, numIterations, DEFAULT_LEARNING_RATE, DEFAULT_PREVENT_OVERFITTING, - DEFAULT_RANDOM_NOISE); - } - - public ParallelArraysSGDFactorizer(FactorizablePreferences factorizablePreferences, int numFeatures, - int numIterations, double learningRate, double 
preventOverfitting, double randomNoise) { - - this.numFeatures = numFeatures; - this.numIterations = numIterations; - minPreference = factorizablePreferences.getMinPreference(); - maxPreference = factorizablePreferences.getMaxPreference(); - - this.random = RandomUtils.getRandom(); - this.learningRate = learningRate; - this.preventOverfitting = preventOverfitting; - - int numUsers = factorizablePreferences.numUsers(); - int numItems = factorizablePreferences.numItems(); - int numPrefs = factorizablePreferences.numPreferences(); - - log.info("Mapping {} users...", numUsers); - userIDMapping = new FastByIDMap<>(numUsers); - int index = 0; - LongPrimitiveIterator userIterator = factorizablePreferences.getUserIDs(); - while (userIterator.hasNext()) { - userIDMapping.put(userIterator.nextLong(), index++); - } - - log.info("Mapping {} items", numItems); - itemIDMapping = new FastByIDMap<>(numItems); - index = 0; - LongPrimitiveIterator itemIterator = factorizablePreferences.getItemIDs(); - while (itemIterator.hasNext()) { - itemIDMapping.put(itemIterator.nextLong(), index++); - } - - this.userIndexes = new int[numPrefs]; - this.itemIndexes = new int[numPrefs]; - this.values = new float[numPrefs]; - this.cachedEstimates = new double[numPrefs]; - - index = 0; - log.info("Loading {} preferences into memory", numPrefs); - RunningAverage average = new FullRunningAverage(); - for (Preference preference : factorizablePreferences.getPreferences()) { - userIndexes[index] = userIDMapping.get(preference.getUserID()); - itemIndexes[index] = itemIDMapping.get(preference.getItemID()); - values[index] = preference.getValue(); - cachedEstimates[index] = 0; - - average.addDatum(preference.getValue()); - - index++; - if (index % 1000000 == 0) { - log.info("Processed {} preferences", index); - } - } - log.info("Processed {} preferences, done.", index); - - double averagePreference = average.getAverage(); - log.info("Average preference value is {}", averagePreference); - - double 
prefInterval = factorizablePreferences.getMaxPreference() - factorizablePreferences.getMinPreference(); - defaultValue = Math.sqrt((averagePreference - prefInterval * 0.1) / numFeatures); - interval = prefInterval * 0.1 / numFeatures; - - userFeatures = new double[numUsers][numFeatures]; - itemFeatures = new double[numItems][numFeatures]; - - log.info("Initializing feature vectors..."); - for (int feature = 0; feature < numFeatures; feature++) { - for (int userIndex = 0; userIndex < numUsers; userIndex++) { - userFeatures[userIndex][feature] = defaultValue + (random.nextDouble() - 0.5) * interval * randomNoise; - } - for (int itemIndex = 0; itemIndex < numItems; itemIndex++) { - itemFeatures[itemIndex][feature] = defaultValue + (random.nextDouble() - 0.5) * interval * randomNoise; - } - } - } - - @Override - public Factorization factorize() throws TasteException { - for (int feature = 0; feature < numFeatures; feature++) { - log.info("Shuffling preferences..."); - shufflePreferences(); - log.info("Starting training of feature {} ...", feature); - for (int currentIteration = 0; currentIteration < numIterations; currentIteration++) { - if (currentIteration == numIterations - 1) { - double rmse = trainingIterationWithRmse(feature); - log.info("Finished training feature {} with RMSE {}", feature, rmse); - } else { - trainingIteration(feature); - } - } - if (feature < numFeatures - 1) { - log.info("Updating cache..."); - for (int index = 0; index < userIndexes.length; index++) { - cachedEstimates[index] = estimate(userIndexes[index], itemIndexes[index], feature, cachedEstimates[index], - false); - } - } - } - log.info("Factorization done"); - return new Factorization(userIDMapping, itemIDMapping, userFeatures, itemFeatures); - } - - private void trainingIteration(int feature) { - for (int index = 0; index < userIndexes.length; index++) { - train(userIndexes[index], itemIndexes[index], feature, values[index], cachedEstimates[index]); - } - } - - private double 
trainingIterationWithRmse(int feature) { - double rmse = 0.0; - for (int index = 0; index < userIndexes.length; index++) { - double error = train(userIndexes[index], itemIndexes[index], feature, values[index], cachedEstimates[index]); - rmse += error * error; - } - return Math.sqrt(rmse / userIndexes.length); - } - - private double estimate(int userIndex, int itemIndex, int feature, double cachedEstimate, boolean trailing) { - double sum = cachedEstimate; - sum += userFeatures[userIndex][feature] * itemFeatures[itemIndex][feature]; - if (trailing) { - sum += (numFeatures - feature - 1) * (defaultValue + interval) * (defaultValue + interval); - if (sum > maxPreference) { - sum = maxPreference; - } else if (sum < minPreference) { - sum = minPreference; - } - } - return sum; - } - - public double train(int userIndex, int itemIndex, int feature, double original, double cachedEstimate) { - double error = original - estimate(userIndex, itemIndex, feature, cachedEstimate, true); - double[] userVector = userFeatures[userIndex]; - double[] itemVector = itemFeatures[itemIndex]; - - userVector[feature] += learningRate * (error * itemVector[feature] - preventOverfitting * userVector[feature]); - itemVector[feature] += learningRate * (error * userVector[feature] - preventOverfitting * itemVector[feature]); - - return error; - } - - protected void shufflePreferences() { - /* Durstenfeld shuffle */ - for (int currentPos = userIndexes.length - 1; currentPos > 0; currentPos--) { - int swapPos = random.nextInt(currentPos + 1); - swapPreferences(currentPos, swapPos); - } - } - - private void swapPreferences(int posA, int posB) { - int tmpUserIndex = userIndexes[posA]; - int tmpItemIndex = itemIndexes[posA]; - float tmpValue = values[posA]; - double tmpEstimate = cachedEstimates[posA]; - - userIndexes[posA] = userIndexes[posB]; - itemIndexes[posA] = itemIndexes[posB]; - values[posA] = values[posB]; - cachedEstimates[posA] = cachedEstimates[posB]; - - userIndexes[posB] = tmpUserIndex; 
- itemIndexes[posB] = tmpItemIndex; - values[posB] = tmpValue; - cachedEstimates[posB] = tmpEstimate; - } - - @Override - public void refresh(Collection<Refreshable> alreadyRefreshed) { - // do nothing - } - -}
http://git-wip-us.apache.org/repos/asf/mahout/blob/99a5358f/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/svd/Track1SVDRunner.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/svd/Track1SVDRunner.java b/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/svd/Track1SVDRunner.java deleted file mode 100644 index 5cce02d..0000000 --- a/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/svd/Track1SVDRunner.java +++ /dev/null @@ -1,141 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.mahout.cf.taste.example.kddcup.track1.svd; - -import java.io.BufferedOutputStream; -import java.io.File; -import java.io.FileOutputStream; -import java.io.OutputStream; - -import com.google.common.io.Closeables; -import org.apache.mahout.cf.taste.common.NoSuchItemException; -import org.apache.mahout.cf.taste.common.NoSuchUserException; -import org.apache.mahout.cf.taste.example.kddcup.DataFileIterable; -import org.apache.mahout.cf.taste.example.kddcup.KDDCupDataModel; -import org.apache.mahout.cf.taste.example.kddcup.track1.EstimateConverter; -import org.apache.mahout.cf.taste.impl.common.FullRunningAverage; -import org.apache.mahout.cf.taste.impl.common.RunningAverage; -import org.apache.mahout.cf.taste.impl.recommender.svd.Factorization; -import org.apache.mahout.cf.taste.impl.recommender.svd.Factorizer; -import org.apache.mahout.cf.taste.model.Preference; -import org.apache.mahout.cf.taste.model.PreferenceArray; -import org.apache.mahout.common.Pair; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -/** - * run an SVD factorization of the KDD track1 data. 
- * - * needs at least 6-7GB of memory, tested with -Xms6700M -Xmx6700M - * - */ -public final class Track1SVDRunner { - - private static final Logger log = LoggerFactory.getLogger(Track1SVDRunner.class); - - private Track1SVDRunner() { - } - - public static void main(String[] args) throws Exception { - - if (args.length != 2) { - System.err.println("Necessary arguments: <kddDataFileDirectory> <resultFile>"); - return; - } - - File dataFileDirectory = new File(args[0]); - if (!dataFileDirectory.exists() || !dataFileDirectory.isDirectory()) { - throw new IllegalArgumentException("Bad data file directory: " + dataFileDirectory); - } - - File resultFile = new File(args[1]); - - /* the knobs to turn */ - int numFeatures = 20; - int numIterations = 5; - double learningRate = 0.0001; - double preventOverfitting = 0.002; - double randomNoise = 0.0001; - - - KDDCupFactorizablePreferences factorizablePreferences = - new KDDCupFactorizablePreferences(KDDCupDataModel.getTrainingFile(dataFileDirectory)); - - Factorizer sgdFactorizer = new ParallelArraysSGDFactorizer(factorizablePreferences, numFeatures, numIterations, - learningRate, preventOverfitting, randomNoise); - - Factorization factorization = sgdFactorizer.factorize(); - - log.info("Estimating validation preferences..."); - int prefsProcessed = 0; - RunningAverage average = new FullRunningAverage(); - for (Pair<PreferenceArray,long[]> validationPair - : new DataFileIterable(KDDCupDataModel.getValidationFile(dataFileDirectory))) { - for (Preference validationPref : validationPair.getFirst()) { - double estimate = estimatePreference(factorization, validationPref.getUserID(), validationPref.getItemID(), - factorizablePreferences.getMinPreference(), factorizablePreferences.getMaxPreference()); - double error = validationPref.getValue() - estimate; - average.addDatum(error * error); - prefsProcessed++; - if (prefsProcessed % 100000 == 0) { - log.info("Computed {} estimations", prefsProcessed); - } - } - } - 
log.info("Computed {} estimations, done.", prefsProcessed); - - double rmse = Math.sqrt(average.getAverage()); - log.info("RMSE {}", rmse); - - log.info("Estimating test preferences..."); - OutputStream out = null; - try { - out = new BufferedOutputStream(new FileOutputStream(resultFile)); - - for (Pair<PreferenceArray,long[]> testPair - : new DataFileIterable(KDDCupDataModel.getTestFile(dataFileDirectory))) { - for (Preference testPref : testPair.getFirst()) { - double estimate = estimatePreference(factorization, testPref.getUserID(), testPref.getItemID(), - factorizablePreferences.getMinPreference(), factorizablePreferences.getMaxPreference()); - byte result = EstimateConverter.convert(estimate, testPref.getUserID(), testPref.getItemID()); - out.write(result); - } - } - } finally { - Closeables.close(out, false); - } - log.info("wrote estimates to {}, done.", resultFile.getAbsolutePath()); - } - - static double estimatePreference(Factorization factorization, long userID, long itemID, float minPreference, - float maxPreference) throws NoSuchUserException, NoSuchItemException { - double[] userFeatures = factorization.getUserFeatures(userID); - double[] itemFeatures = factorization.getItemFeatures(itemID); - double estimate = 0; - for (int feature = 0; feature < userFeatures.length; feature++) { - estimate += userFeatures[feature] * itemFeatures[feature]; - } - if (estimate < minPreference) { - estimate = minPreference; - } else if (estimate > maxPreference) { - estimate = maxPreference; - } - return estimate; - } - -} http://git-wip-us.apache.org/repos/asf/mahout/blob/99a5358f/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/HybridSimilarity.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/HybridSimilarity.java b/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/HybridSimilarity.java deleted file mode 100644 
index ce025a9..0000000 --- a/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/HybridSimilarity.java +++ /dev/null @@ -1,62 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.mahout.cf.taste.example.kddcup.track2; - -import java.io.File; -import java.io.IOException; -import java.util.Collection; - -import org.apache.mahout.cf.taste.common.Refreshable; -import org.apache.mahout.cf.taste.common.TasteException; -import org.apache.mahout.cf.taste.impl.similarity.AbstractItemSimilarity; -import org.apache.mahout.cf.taste.impl.similarity.LogLikelihoodSimilarity; -import org.apache.mahout.cf.taste.model.DataModel; -import org.apache.mahout.cf.taste.similarity.ItemSimilarity; - -final class HybridSimilarity extends AbstractItemSimilarity { - - private final ItemSimilarity cfSimilarity; - private final ItemSimilarity contentSimilarity; - - HybridSimilarity(DataModel dataModel, File dataFileDirectory) throws IOException { - super(dataModel); - cfSimilarity = new LogLikelihoodSimilarity(dataModel); - contentSimilarity = new TrackItemSimilarity(dataFileDirectory); - } - - @Override - public double itemSimilarity(long itemID1, long itemID2) throws TasteException { - return 
contentSimilarity.itemSimilarity(itemID1, itemID2) * cfSimilarity.itemSimilarity(itemID1, itemID2); - } - - @Override - public double[] itemSimilarities(long itemID1, long[] itemID2s) throws TasteException { - double[] result = contentSimilarity.itemSimilarities(itemID1, itemID2s); - double[] multipliers = cfSimilarity.itemSimilarities(itemID1, itemID2s); - for (int i = 0; i < result.length; i++) { - result[i] *= multipliers[i]; - } - return result; - } - - @Override - public void refresh(Collection<Refreshable> alreadyRefreshed) { - cfSimilarity.refresh(alreadyRefreshed); - } - -} http://git-wip-us.apache.org/repos/asf/mahout/blob/99a5358f/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/Track2Callable.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/Track2Callable.java b/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/Track2Callable.java deleted file mode 100644 index 50fd35e..0000000 --- a/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/Track2Callable.java +++ /dev/null @@ -1,106 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.mahout.cf.taste.example.kddcup.track2; - -import org.apache.mahout.cf.taste.common.NoSuchItemException; -import org.apache.mahout.cf.taste.common.TasteException; -import org.apache.mahout.cf.taste.model.PreferenceArray; -import org.apache.mahout.cf.taste.recommender.Recommender; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import java.util.ArrayList; -import java.util.Collection; -import java.util.Collections; -import java.util.HashSet; -import java.util.List; -import java.util.TreeMap; -import java.util.concurrent.Callable; -import java.util.concurrent.atomic.AtomicInteger; - -final class Track2Callable implements Callable<UserResult> { - - private static final Logger log = LoggerFactory.getLogger(Track2Callable.class); - private static final AtomicInteger COUNT = new AtomicInteger(); - - private final Recommender recommender; - private final PreferenceArray userTest; - - Track2Callable(Recommender recommender, PreferenceArray userTest) { - this.recommender = recommender; - this.userTest = userTest; - } - - @Override - public UserResult call() throws TasteException { - - int testSize = userTest.length(); - if (testSize != 6) { - throw new IllegalArgumentException("Expecting 6 items for user but got " + userTest); - } - long userID = userTest.get(0).getUserID(); - TreeMap<Double,Long> estimateToItemID = new TreeMap<>(Collections.reverseOrder()); - - for (int i = 0; i < testSize; i++) { - long itemID = userTest.getItemID(i); - double estimate; - try { - estimate = recommender.estimatePreference(userID, itemID); - } catch (NoSuchItemException nsie) { - // OK in the sample data provided before the contest, should never happen otherwise - log.warn("Unknown item {}; OK unless this is the real contest data", itemID); - continue; - } - - if (!Double.isNaN(estimate)) { - estimateToItemID.put(estimate, itemID); - } - } - - Collection<Long> itemIDs = estimateToItemID.values(); - List<Long> topThree = new ArrayList<>(itemIDs); - if 
(topThree.size() > 3) { - topThree = topThree.subList(0, 3); - } else if (topThree.size() < 3) { - log.warn("Unable to recommend three items for {}", userID); - // Some NaNs - just guess at the rest then - Collection<Long> newItemIDs = new HashSet<>(3); - newItemIDs.addAll(itemIDs); - int i = 0; - while (i < testSize && newItemIDs.size() < 3) { - newItemIDs.add(userTest.getItemID(i)); - i++; - } - topThree = new ArrayList<>(newItemIDs); - } - if (topThree.size() != 3) { - throw new IllegalStateException(); - } - - boolean[] result = new boolean[testSize]; - for (int i = 0; i < testSize; i++) { - result[i] = topThree.contains(userTest.getItemID(i)); - } - - if (COUNT.incrementAndGet() % 1000 == 0) { - log.info("Completed {} users", COUNT.get()); - } - - return new UserResult(userID, result); - } -} http://git-wip-us.apache.org/repos/asf/mahout/blob/99a5358f/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/Track2Recommender.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/Track2Recommender.java b/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/Track2Recommender.java deleted file mode 100644 index 185a00d..0000000 --- a/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/Track2Recommender.java +++ /dev/null @@ -1,100 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.mahout.cf.taste.example.kddcup.track2; - -import java.io.File; -import java.io.IOException; -import java.util.Collection; -import java.util.List; - -import org.apache.mahout.cf.taste.common.Refreshable; -import org.apache.mahout.cf.taste.common.TasteException; -import org.apache.mahout.cf.taste.impl.recommender.GenericBooleanPrefItemBasedRecommender; -import org.apache.mahout.cf.taste.model.DataModel; -import org.apache.mahout.cf.taste.recommender.IDRescorer; -import org.apache.mahout.cf.taste.recommender.RecommendedItem; -import org.apache.mahout.cf.taste.recommender.Recommender; -import org.apache.mahout.cf.taste.similarity.ItemSimilarity; - -public final class Track2Recommender implements Recommender { - - private final Recommender recommender; - - public Track2Recommender(DataModel dataModel, File dataFileDirectory) throws TasteException { - // Change this to whatever you like! 
- ItemSimilarity similarity; - try { - similarity = new HybridSimilarity(dataModel, dataFileDirectory); - } catch (IOException ioe) { - throw new TasteException(ioe); - } - recommender = new GenericBooleanPrefItemBasedRecommender(dataModel, similarity); - } - - @Override - public List<RecommendedItem> recommend(long userID, int howMany) throws TasteException { - return recommender.recommend(userID, howMany); - } - - @Override - public List<RecommendedItem> recommend(long userID, int howMany, boolean includeKnownItems) throws TasteException { - return recommend(userID, howMany, null, includeKnownItems); - } - - @Override - public List<RecommendedItem> recommend(long userID, int howMany, IDRescorer rescorer) throws TasteException { - return recommender.recommend(userID, howMany, rescorer, false); - } - - @Override - public List<RecommendedItem> recommend(long userID, int howMany, IDRescorer rescorer, boolean includeKnownItems) - throws TasteException { - return recommender.recommend(userID, howMany, rescorer, includeKnownItems); - } - - @Override - public float estimatePreference(long userID, long itemID) throws TasteException { - return recommender.estimatePreference(userID, itemID); - } - - @Override - public void setPreference(long userID, long itemID, float value) throws TasteException { - recommender.setPreference(userID, itemID, value); - } - - @Override - public void removePreference(long userID, long itemID) throws TasteException { - recommender.removePreference(userID, itemID); - } - - @Override - public DataModel getDataModel() { - return recommender.getDataModel(); - } - - @Override - public void refresh(Collection<Refreshable> alreadyRefreshed) { - recommender.refresh(alreadyRefreshed); - } - - @Override - public String toString() { - return "Track1Recommender[recommender:" + recommender + ']'; - } - -} 
http://git-wip-us.apache.org/repos/asf/mahout/blob/99a5358f/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/Track2RecommenderBuilder.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/Track2RecommenderBuilder.java b/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/Track2RecommenderBuilder.java deleted file mode 100644 index 09ade5d..0000000 --- a/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/Track2RecommenderBuilder.java +++ /dev/null @@ -1,33 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.mahout.cf.taste.example.kddcup.track2; - -import org.apache.mahout.cf.taste.common.TasteException; -import org.apache.mahout.cf.taste.eval.RecommenderBuilder; -import org.apache.mahout.cf.taste.example.kddcup.KDDCupDataModel; -import org.apache.mahout.cf.taste.model.DataModel; -import org.apache.mahout.cf.taste.recommender.Recommender; - -final class Track2RecommenderBuilder implements RecommenderBuilder { - - @Override - public Recommender buildRecommender(DataModel dataModel) throws TasteException { - return new Track2Recommender(dataModel, ((KDDCupDataModel) dataModel).getDataFileDirectory()); - } - -} http://git-wip-us.apache.org/repos/asf/mahout/blob/99a5358f/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/Track2Runner.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/Track2Runner.java b/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/Track2Runner.java deleted file mode 100644 index 3cbb61c..0000000 --- a/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/Track2Runner.java +++ /dev/null @@ -1,100 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.mahout.cf.taste.example.kddcup.track2; - -import org.apache.mahout.cf.taste.example.kddcup.DataFileIterable; -import org.apache.mahout.cf.taste.example.kddcup.KDDCupDataModel; -import org.apache.mahout.cf.taste.model.PreferenceArray; -import org.apache.mahout.common.Pair; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import java.io.BufferedOutputStream; -import java.io.File; -import java.io.FileOutputStream; -import java.io.OutputStream; -import java.util.ArrayList; -import java.util.Collection; -import java.util.List; -import java.util.concurrent.ExecutorService; -import java.util.concurrent.Executors; -import java.util.concurrent.Future; - -/** - * <p>Runs "track 2" of the KDD Cup competition using whatever recommender is inside {@link Track2Recommender} - * and attempts to output the result in the correct contest format.</p> - * - * <p>Run as: {@code Track2Runner [track 2 data file directory] [output file]}</p> - */ -public final class Track2Runner { - - private static final Logger log = LoggerFactory.getLogger(Track2Runner.class); - - private Track2Runner() { - } - - public static void main(String[] args) throws Exception { - - File dataFileDirectory = new File(args[0]); - if (!dataFileDirectory.exists() || !dataFileDirectory.isDirectory()) { - throw new IllegalArgumentException("Bad data file directory: " + dataFileDirectory); - } - - long start = System.currentTimeMillis(); - - KDDCupDataModel model = new KDDCupDataModel(KDDCupDataModel.getTrainingFile(dataFileDirectory)); - Track2Recommender recommender = new Track2Recommender(model, dataFileDirectory); - - long end = System.currentTimeMillis(); - log.info("Loaded model in {}s", (end - start) / 1000); - start = end; - - Collection<Track2Callable> callables = new ArrayList<>(); - for (Pair<PreferenceArray,long[]> tests : new 
DataFileIterable(KDDCupDataModel.getTestFile(dataFileDirectory))) { - PreferenceArray userTest = tests.getFirst(); - callables.add(new Track2Callable(recommender, userTest)); - } - - int cores = Runtime.getRuntime().availableProcessors(); - log.info("Running on {} cores", cores); - ExecutorService executor = Executors.newFixedThreadPool(cores); - List<Future<UserResult>> futures = executor.invokeAll(callables); - executor.shutdown(); - - end = System.currentTimeMillis(); - log.info("Ran recommendations in {}s", (end - start) / 1000); - start = end; - - try (OutputStream out = new BufferedOutputStream(new FileOutputStream(new File(args[1])))){ - long lastUserID = Long.MIN_VALUE; - for (Future<UserResult> future : futures) { - UserResult result = future.get(); - long userID = result.getUserID(); - if (userID <= lastUserID) { - throw new IllegalStateException(); - } - lastUserID = userID; - out.write(result.getResultBytes()); - } - } - - end = System.currentTimeMillis(); - log.info("Wrote output in {}s", (end - start) / 1000); - } - -} http://git-wip-us.apache.org/repos/asf/mahout/blob/99a5358f/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/TrackData.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/TrackData.java b/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/TrackData.java deleted file mode 100644 index abd15f8..0000000 --- a/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/TrackData.java +++ /dev/null @@ -1,71 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. 
- * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.mahout.cf.taste.example.kddcup.track2; - -import java.util.regex.Pattern; - -import org.apache.mahout.cf.taste.impl.common.FastIDSet; - -final class TrackData { - - private static final Pattern PIPE = Pattern.compile("\\|"); - private static final String NO_VALUE = "None"; - static final long NO_VALUE_ID = Long.MIN_VALUE; - private static final FastIDSet NO_GENRES = new FastIDSet(); - - private final long trackID; - private final long albumID; - private final long artistID; - private final FastIDSet genreIDs; - - TrackData(CharSequence line) { - String[] tokens = PIPE.split(line); - trackID = Long.parseLong(tokens[0]); - albumID = parse(tokens[1]); - artistID = parse(tokens[2]); - if (tokens.length > 3) { - genreIDs = new FastIDSet(tokens.length - 3); - for (int i = 3; i < tokens.length; i++) { - genreIDs.add(Long.parseLong(tokens[i])); - } - } else { - genreIDs = NO_GENRES; - } - } - - private static long parse(String value) { - return NO_VALUE.equals(value) ? 
NO_VALUE_ID : Long.parseLong(value); - } - - public long getTrackID() { - return trackID; - } - - public long getAlbumID() { - return albumID; - } - - public long getArtistID() { - return artistID; - } - - public FastIDSet getGenreIDs() { - return genreIDs; - } - -} http://git-wip-us.apache.org/repos/asf/mahout/blob/99a5358f/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/TrackItemSimilarity.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/TrackItemSimilarity.java b/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/TrackItemSimilarity.java deleted file mode 100644 index 3012a84..0000000 --- a/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/TrackItemSimilarity.java +++ /dev/null @@ -1,106 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.mahout.cf.taste.example.kddcup.track2; - -import java.io.File; -import java.io.IOException; -import java.util.Collection; - -import org.apache.mahout.cf.taste.common.Refreshable; -import org.apache.mahout.cf.taste.example.kddcup.KDDCupDataModel; -import org.apache.mahout.cf.taste.impl.common.FastByIDMap; -import org.apache.mahout.cf.taste.impl.common.FastIDSet; -import org.apache.mahout.cf.taste.impl.common.LongPrimitiveIterator; -import org.apache.mahout.cf.taste.similarity.ItemSimilarity; -import org.apache.mahout.common.iterator.FileLineIterable; - -final class TrackItemSimilarity implements ItemSimilarity { - - private final FastByIDMap<TrackData> trackData; - - TrackItemSimilarity(File dataFileDirectory) throws IOException { - trackData = new FastByIDMap<>(); - for (String line : new FileLineIterable(KDDCupDataModel.getTrackFile(dataFileDirectory))) { - TrackData trackDatum = new TrackData(line); - trackData.put(trackDatum.getTrackID(), trackDatum); - } - } - - @Override - public double itemSimilarity(long itemID1, long itemID2) { - if (itemID1 == itemID2) { - return 1.0; - } - TrackData data1 = trackData.get(itemID1); - TrackData data2 = trackData.get(itemID2); - if (data1 == null || data2 == null) { - return 0.0; - } - - // Arbitrarily decide that same album means "very similar" - if (data1.getAlbumID() != TrackData.NO_VALUE_ID && data1.getAlbumID() == data2.getAlbumID()) { - return 0.9; - } - // ... 
and same artist means "fairly similar" - if (data1.getArtistID() != TrackData.NO_VALUE_ID && data1.getArtistID() == data2.getArtistID()) { - return 0.7; - } - - // Tanimoto coefficient similarity based on genre, but maximum value of 0.25 - FastIDSet genres1 = data1.getGenreIDs(); - FastIDSet genres2 = data2.getGenreIDs(); - if (genres1 == null || genres2 == null) { - return 0.0; - } - int intersectionSize = genres1.intersectionSize(genres2); - if (intersectionSize == 0) { - return 0.0; - } - int unionSize = genres1.size() + genres2.size() - intersectionSize; - return intersectionSize / (4.0 * unionSize); - } - - @Override - public double[] itemSimilarities(long itemID1, long[] itemID2s) { - int length = itemID2s.length; - double[] result = new double[length]; - for (int i = 0; i < length; i++) { - result[i] = itemSimilarity(itemID1, itemID2s[i]); - } - return result; - } - - @Override - public long[] allSimilarItemIDs(long itemID) { - FastIDSet allSimilarItemIDs = new FastIDSet(); - LongPrimitiveIterator allItemIDs = trackData.keySetIterator(); - while (allItemIDs.hasNext()) { - long possiblySimilarItemID = allItemIDs.nextLong(); - if (!Double.isNaN(itemSimilarity(itemID, possiblySimilarItemID))) { - allSimilarItemIDs.add(possiblySimilarItemID); - } - } - return allSimilarItemIDs.toArray(); - } - - @Override - public void refresh(Collection<Refreshable> alreadyRefreshed) { - // do nothing - } - -} http://git-wip-us.apache.org/repos/asf/mahout/blob/99a5358f/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/UserResult.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/UserResult.java b/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/UserResult.java deleted file mode 100644 index e554d10..0000000 --- a/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/UserResult.java +++ /dev/null @@ -1,54 
/**
 * Immutable result for one test user: the user ID plus one ASCII byte per candidate item,
 * {@code '1'} where the flag was {@code true} and {@code '0'} where it was {@code false}.
 */
final class UserResult {

  /** Number of {@code true} flags every result must contain. */
  private static final int EXPECTED_TRUE_COUNT = 3;

  private final long userID;
  private final byte[] resultBytes;

  /**
   * @param userID ID of the user the flags belong to
   * @param result one flag per candidate item; exactly three must be {@code true}
   * @throws IllegalStateException if the number of {@code true} flags is not three
   */
  UserResult(long userID, boolean[] result) {

    this.userID = userID;

    int trueCount = 0;
    for (boolean b : result) {
      if (b) {
        trueCount++;
      }
    }
    if (trueCount != EXPECTED_TRUE_COUNT) {
      // Say which user is wrong and by how much, instead of a bare exception.
      throw new IllegalStateException("Expected " + EXPECTED_TRUE_COUNT
          + " positive results for user " + userID + " but found " + trueCount);
    }

    resultBytes = new byte[result.length];
    for (int i = 0; i < result.length; i++) {
      resultBytes[i] = (byte) (result[i] ? '1' : '0');
    }
  }

  public long getUserID() {
    return userID;
  }

  /**
   * @return a fresh copy of the encoded flags, so callers cannot alter this object's state
   */
  public byte[] getResultBytes() {
    return resultBytes.clone();
  }

}
- */ - -package org.apache.mahout.cf.taste.hadoop.example.als.netflix; - -import com.google.common.base.Preconditions; -import org.apache.commons.io.Charsets; -import org.apache.hadoop.conf.Configuration; -import org.apache.hadoop.fs.FileSystem; -import org.apache.hadoop.fs.Path; -import org.apache.mahout.cf.taste.impl.model.GenericPreference; -import org.apache.mahout.cf.taste.model.Preference; -import org.apache.mahout.common.iterator.FileLineIterable; -import org.apache.mahout.common.iterator.FileLineIterator; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import java.io.BufferedWriter; -import java.io.File; -import java.io.IOException; -import java.io.OutputStreamWriter; -import java.util.ArrayList; -import java.util.List; -import java.util.regex.Pattern; - -/** converts the raw files provided by netflix to an appropriate input format */ -public final class NetflixDatasetConverter { - - private static final Logger log = LoggerFactory.getLogger(NetflixDatasetConverter.class); - - private static final Pattern SEPARATOR = Pattern.compile(","); - private static final String MOVIE_DENOTER = ":"; - private static final String TAB = "\t"; - private static final String NEWLINE = "\n"; - - private NetflixDatasetConverter() { - } - - public static void main(String[] args) throws IOException { - - if (args.length != 4) { - System.err.println("Usage: NetflixDatasetConverter /path/to/training_set/ /path/to/qualifying.txt " - + "/path/to/judging.txt /path/to/destination"); - return; - } - - String trainingDataDir = args[0]; - String qualifyingTxt = args[1]; - String judgingTxt = args[2]; - Path outputPath = new Path(args[3]); - - Configuration conf = new Configuration(); - FileSystem fs = FileSystem.get(outputPath.toUri(), conf); - - Preconditions.checkArgument(trainingDataDir != null, "Training Data location needs to be specified"); - log.info("Creating training set at {}/trainingSet/ratings.tsv ...", outputPath); - try (BufferedWriter writer = - new 
BufferedWriter( - new OutputStreamWriter( - fs.create(new Path(outputPath, "trainingSet/ratings.tsv")), Charsets.UTF_8))){ - - int ratingsProcessed = 0; - for (File movieRatings : new File(trainingDataDir).listFiles()) { - try (FileLineIterator lines = new FileLineIterator(movieRatings)) { - boolean firstLineRead = false; - String movieID = null; - while (lines.hasNext()) { - String line = lines.next(); - if (firstLineRead) { - String[] tokens = SEPARATOR.split(line); - String userID = tokens[0]; - String rating = tokens[1]; - writer.write(userID + TAB + movieID + TAB + rating + NEWLINE); - ratingsProcessed++; - if (ratingsProcessed % 1000000 == 0) { - log.info("{} ratings processed...", ratingsProcessed); - } - } else { - movieID = line.replaceAll(MOVIE_DENOTER, ""); - firstLineRead = true; - } - } - } - - } - log.info("{} ratings processed. done.", ratingsProcessed); - } - - log.info("Reading probes..."); - List<Preference> probes = new ArrayList<>(2817131); - long currentMovieID = -1; - for (String line : new FileLineIterable(new File(qualifyingTxt))) { - if (line.contains(MOVIE_DENOTER)) { - currentMovieID = Long.parseLong(line.replaceAll(MOVIE_DENOTER, "")); - } else { - long userID = Long.parseLong(SEPARATOR.split(line)[0]); - probes.add(new GenericPreference(userID, currentMovieID, 0)); - } - } - log.info("{} probes read...", probes.size()); - - log.info("Reading ratings, creating probe set at {}/probeSet/ratings.tsv ...", outputPath); - try (BufferedWriter writer = - new BufferedWriter(new OutputStreamWriter( - fs.create(new Path(outputPath, "probeSet/ratings.tsv")), Charsets.UTF_8))){ - int ratingsProcessed = 0; - for (String line : new FileLineIterable(new File(judgingTxt))) { - if (line.contains(MOVIE_DENOTER)) { - currentMovieID = Long.parseLong(line.replaceAll(MOVIE_DENOTER, "")); - } else { - float rating = Float.parseFloat(SEPARATOR.split(line)[0]); - Preference pref = probes.get(ratingsProcessed); - Preconditions.checkState(pref.getItemID() == 
currentMovieID); - ratingsProcessed++; - writer.write(pref.getUserID() + TAB + pref.getItemID() + TAB + rating + NEWLINE); - if (ratingsProcessed % 1000000 == 0) { - log.info("{} ratings processed...", ratingsProcessed); - } - } - } - log.info("{} ratings processed. done.", ratingsProcessed); - } - } - - -} http://git-wip-us.apache.org/repos/asf/mahout/blob/99a5358f/examples/src/main/java/org/apache/mahout/cf/taste/similarity/precompute/example/BatchItemSimilaritiesGroupLens.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/mahout/cf/taste/similarity/precompute/example/BatchItemSimilaritiesGroupLens.java b/examples/src/main/java/org/apache/mahout/cf/taste/similarity/precompute/example/BatchItemSimilaritiesGroupLens.java deleted file mode 100644 index 8021d00..0000000 --- a/examples/src/main/java/org/apache/mahout/cf/taste/similarity/precompute/example/BatchItemSimilaritiesGroupLens.java +++ /dev/null @@ -1,65 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.mahout.cf.taste.similarity.precompute.example; - -import org.apache.mahout.cf.taste.impl.recommender.GenericItemBasedRecommender; -import org.apache.mahout.cf.taste.impl.similarity.LogLikelihoodSimilarity; -import org.apache.mahout.cf.taste.impl.similarity.precompute.FileSimilarItemsWriter; -import org.apache.mahout.cf.taste.impl.similarity.precompute.MultithreadedBatchItemSimilarities; -import org.apache.mahout.cf.taste.model.DataModel; -import org.apache.mahout.cf.taste.recommender.ItemBasedRecommender; -import org.apache.mahout.cf.taste.similarity.precompute.BatchItemSimilarities; - -import java.io.File; - -/** - * Example that precomputes all item similarities of the Movielens1M dataset - * - * Usage: download movielens1M from http://www.grouplens.org/node/73 , unzip it and invoke this code with the path - * to the ratings.dat file as argument - * - */ -public final class BatchItemSimilaritiesGroupLens { - - private BatchItemSimilaritiesGroupLens() {} - - public static void main(String[] args) throws Exception { - - if (args.length != 1) { - System.err.println("Need path to ratings.dat of the movielens1M dataset as argument!"); - System.exit(-1); - } - - File resultFile = new File(System.getProperty("java.io.tmpdir"), "similarities.csv"); - if (resultFile.exists()) { - resultFile.delete(); - } - - DataModel dataModel = new GroupLensDataModel(new File(args[0])); - ItemBasedRecommender recommender = new GenericItemBasedRecommender(dataModel, - new LogLikelihoodSimilarity(dataModel)); - BatchItemSimilarities batch = new MultithreadedBatchItemSimilarities(recommender, 5); - - int numSimilarities = batch.computeItemSimilarities(Runtime.getRuntime().availableProcessors(), 1, - new FileSimilarItemsWriter(resultFile)); - - System.out.println("Computed " + numSimilarities + " similarities for " + dataModel.getNumItems() + " items " - + "and saved them to " + resultFile.getAbsolutePath()); - } - -} 
http://git-wip-us.apache.org/repos/asf/mahout/blob/99a5358f/examples/src/main/java/org/apache/mahout/cf/taste/similarity/precompute/example/GroupLensDataModel.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/mahout/cf/taste/similarity/precompute/example/GroupLensDataModel.java b/examples/src/main/java/org/apache/mahout/cf/taste/similarity/precompute/example/GroupLensDataModel.java deleted file mode 100644 index 7ee9b17..0000000 --- a/examples/src/main/java/org/apache/mahout/cf/taste/similarity/precompute/example/GroupLensDataModel.java +++ /dev/null @@ -1,96 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.mahout.cf.taste.similarity.precompute.example; - -import com.google.common.io.Files; -import com.google.common.io.InputSupplier; -import com.google.common.io.Resources; -import java.io.File; -import java.io.FileOutputStream; -import java.io.IOException; -import java.io.InputStream; -import java.io.OutputStreamWriter; -import java.io.Writer; -import java.net.URL; -import java.util.regex.Pattern; -import org.apache.commons.io.Charsets; -import org.apache.mahout.cf.taste.impl.model.file.FileDataModel; -import org.apache.mahout.common.iterator.FileLineIterable; - -public final class GroupLensDataModel extends FileDataModel { - - private static final String COLON_DELIMTER = "::"; - private static final Pattern COLON_DELIMITER_PATTERN = Pattern.compile(COLON_DELIMTER); - - public GroupLensDataModel() throws IOException { - this(readResourceToTempFile("/org/apache/mahout/cf/taste/example/grouplens/ratings.dat")); - } - - /** - * @param ratingsFile GroupLens ratings.dat file in its native format - * @throws IOException if an error occurs while reading or writing files - */ - public GroupLensDataModel(File ratingsFile) throws IOException { - super(convertGLFile(ratingsFile)); - } - - private static File convertGLFile(File originalFile) throws IOException { - // Now translate the file; remove commas, then convert "::" delimiter to comma - File resultFile = new File(new File(System.getProperty("java.io.tmpdir")), "ratings.txt"); - if (resultFile.exists()) { - resultFile.delete(); - } - try (Writer writer = new OutputStreamWriter(new FileOutputStream(resultFile), Charsets.UTF_8)){ - for (String line : new FileLineIterable(originalFile, false)) { - int lastDelimiterStart = line.lastIndexOf(COLON_DELIMTER); - if (lastDelimiterStart < 0) { - throw new IOException("Unexpected input format on line: " + line); - } - String subLine = line.substring(0, lastDelimiterStart); - String convertedLine = COLON_DELIMITER_PATTERN.matcher(subLine).replaceAll(","); - 
writer.write(convertedLine); - writer.write('\n'); - } - } catch (IOException ioe) { - resultFile.delete(); - throw ioe; - } - return resultFile; - } - - public static File readResourceToTempFile(String resourceName) throws IOException { - InputSupplier<? extends InputStream> inSupplier; - try { - URL resourceURL = Resources.getResource(GroupLensDataModel.class, resourceName); - inSupplier = Resources.newInputStreamSupplier(resourceURL); - } catch (IllegalArgumentException iae) { - File resourceFile = new File("src/main/java" + resourceName); - inSupplier = Files.newInputStreamSupplier(resourceFile); - } - File tempFile = File.createTempFile("taste", null); - tempFile.deleteOnExit(); - Files.copy(inSupplier, tempFile); - return tempFile; - } - - @Override - public String toString() { - return "GroupLensDataModel"; - } - -} http://git-wip-us.apache.org/repos/asf/mahout/blob/99a5358f/examples/src/main/java/org/apache/mahout/classifier/NewsgroupHelper.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/mahout/classifier/NewsgroupHelper.java b/examples/src/main/java/org/apache/mahout/classifier/NewsgroupHelper.java deleted file mode 100644 index 5cec51c..0000000 --- a/examples/src/main/java/org/apache/mahout/classifier/NewsgroupHelper.java +++ /dev/null @@ -1,128 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.mahout.classifier; - -import com.google.common.collect.ConcurrentHashMultiset; -import com.google.common.collect.Multiset; -import com.google.common.io.Closeables; -import com.google.common.io.Files; -import org.apache.commons.io.Charsets; -import org.apache.lucene.analysis.Analyzer; -import org.apache.lucene.analysis.TokenStream; -import org.apache.lucene.analysis.standard.StandardAnalyzer; -import org.apache.lucene.analysis.tokenattributes.CharTermAttribute; -import org.apache.mahout.common.RandomUtils; -import org.apache.mahout.math.RandomAccessSparseVector; -import org.apache.mahout.math.Vector; -import org.apache.mahout.vectorizer.encoders.ConstantValueEncoder; -import org.apache.mahout.vectorizer.encoders.FeatureVectorEncoder; -import org.apache.mahout.vectorizer.encoders.StaticWordValueEncoder; - -import java.io.BufferedReader; -import java.io.File; -import java.io.IOException; -import java.io.Reader; -import java.io.StringReader; -import java.text.SimpleDateFormat; -import java.util.Collection; -import java.util.Date; -import java.util.Locale; -import java.util.Random; - -public final class NewsgroupHelper { - - private static final SimpleDateFormat[] DATE_FORMATS = { - new SimpleDateFormat("", Locale.ENGLISH), - new SimpleDateFormat("MMM-yyyy", Locale.ENGLISH), - new SimpleDateFormat("dd-MMM-yyyy HH:mm:ss", Locale.ENGLISH) - }; - - public static final int FEATURES = 10000; - // 1997-01-15 00:01:00 GMT - private static final long DATE_REFERENCE = 853286460; - private static final long MONTH = 30 * 24 * 3600; - private 
static final long WEEK = 7 * 24 * 3600; - - private final Random rand = RandomUtils.getRandom(); - private final Analyzer analyzer = new StandardAnalyzer(); - private final FeatureVectorEncoder encoder = new StaticWordValueEncoder("body"); - private final FeatureVectorEncoder bias = new ConstantValueEncoder("Intercept"); - - public FeatureVectorEncoder getEncoder() { - return encoder; - } - - public FeatureVectorEncoder getBias() { - return bias; - } - - public Random getRandom() { - return rand; - } - - public Vector encodeFeatureVector(File file, int actual, int leakType, Multiset<String> overallCounts) - throws IOException { - long date = (long) (1000 * (DATE_REFERENCE + actual * MONTH + 1 * WEEK * rand.nextDouble())); - Multiset<String> words = ConcurrentHashMultiset.create(); - - try (BufferedReader reader = Files.newReader(file, Charsets.UTF_8)) { - String line = reader.readLine(); - Reader dateString = new StringReader(DATE_FORMATS[leakType % 3].format(new Date(date))); - countWords(analyzer, words, dateString, overallCounts); - while (line != null && !line.isEmpty()) { - boolean countHeader = ( - line.startsWith("From:") || line.startsWith("Subject:") - || line.startsWith("Keywords:") || line.startsWith("Summary:")) && leakType < 6; - do { - Reader in = new StringReader(line); - if (countHeader) { - countWords(analyzer, words, in, overallCounts); - } - line = reader.readLine(); - } while (line != null && line.startsWith(" ")); - } - if (leakType < 3) { - countWords(analyzer, words, reader, overallCounts); - } - } - - Vector v = new RandomAccessSparseVector(FEATURES); - bias.addToVector("", 1, v); - for (String word : words.elementSet()) { - encoder.addToVector(word, Math.log1p(words.count(word)), v); - } - - return v; - } - - public static void countWords(Analyzer analyzer, - Collection<String> words, - Reader in, - Multiset<String> overallCounts) throws IOException { - TokenStream ts = analyzer.tokenStream("text", in); - 
ts.addAttribute(CharTermAttribute.class); - ts.reset(); - while (ts.incrementToken()) { - String s = ts.getAttribute(CharTermAttribute.class).toString(); - words.add(s); - } - overallCounts.addAll(words); - ts.end(); - Closeables.close(ts, true); - } -} http://git-wip-us.apache.org/repos/asf/mahout/blob/99a5358f/examples/src/main/java/org/apache/mahout/classifier/email/PrepEmailMapper.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/mahout/classifier/email/PrepEmailMapper.java b/examples/src/main/java/org/apache/mahout/classifier/email/PrepEmailMapper.java deleted file mode 100644 index 16e9d80..0000000 --- a/examples/src/main/java/org/apache/mahout/classifier/email/PrepEmailMapper.java +++ /dev/null @@ -1,65 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.mahout.classifier.email; - -import org.apache.hadoop.io.Text; -import org.apache.hadoop.io.WritableComparable; -import org.apache.hadoop.mapreduce.Mapper; -import org.apache.mahout.math.VectorWritable; - -import java.io.IOException; -import java.util.Locale; -import java.util.regex.Pattern; - -/** - * Convert the labels created by the {@link org.apache.mahout.utils.email.MailProcessor} to one consumable - * by the classifiers - */ -public class PrepEmailMapper extends Mapper<WritableComparable<?>, VectorWritable, Text, VectorWritable> { - - private static final Pattern DASH_DOT = Pattern.compile("-|\\."); - private static final Pattern SLASH = Pattern.compile("\\/"); - - private boolean useListName = false; //if true, use the project name and the list name in label creation - @Override - protected void setup(Context context) throws IOException, InterruptedException { - useListName = Boolean.parseBoolean(context.getConfiguration().get(PrepEmailVectorsDriver.USE_LIST_NAME)); - } - - @Override - protected void map(WritableComparable<?> key, VectorWritable value, Context context) - throws IOException, InterruptedException { - String input = key.toString(); - ///Example: /cocoon.apache.org/dev/200307.gz/001401c3414f$8394e160$1e01a8c0@WRPO - String[] splits = SLASH.split(input); - //we need the first two splits; - if (splits.length >= 3) { - StringBuilder bldr = new StringBuilder(); - bldr.append(escape(splits[1])); - if (useListName) { - bldr.append('_').append(escape(splits[2])); - } - context.write(new Text(bldr.toString()), value); - } - - } - - private static String escape(CharSequence value) { - return DASH_DOT.matcher(value).replaceAll("_").toLowerCase(Locale.ENGLISH); - } -} http://git-wip-us.apache.org/repos/asf/mahout/blob/99a5358f/examples/src/main/java/org/apache/mahout/classifier/email/PrepEmailReducer.java ---------------------------------------------------------------------- diff --git 
a/examples/src/main/java/org/apache/mahout/classifier/email/PrepEmailReducer.java b/examples/src/main/java/org/apache/mahout/classifier/email/PrepEmailReducer.java deleted file mode 100644 index da6e613..0000000 --- a/examples/src/main/java/org/apache/mahout/classifier/email/PrepEmailReducer.java +++ /dev/null @@ -1,47 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.mahout.classifier.email; - -import org.apache.hadoop.io.Text; -import org.apache.hadoop.mapreduce.Reducer; -import org.apache.mahout.math.VectorWritable; - -import java.io.IOException; -import java.util.Iterator; - -public class PrepEmailReducer extends Reducer<Text, VectorWritable, Text, VectorWritable> { - - private long maxItemsPerLabel = 10000; - - @Override - protected void setup(Context context) throws IOException, InterruptedException { - maxItemsPerLabel = Long.parseLong(context.getConfiguration().get(PrepEmailVectorsDriver.ITEMS_PER_CLASS)); - } - - @Override - protected void reduce(Text key, Iterable<VectorWritable> values, Context context) - throws IOException, InterruptedException { - //TODO: support randomization? 
Likely not needed due to the SplitInput utility which does random selection - long i = 0; - Iterator<VectorWritable> iterator = values.iterator(); - while (i < maxItemsPerLabel && iterator.hasNext()) { - context.write(key, iterator.next()); - i++; - } - } -} http://git-wip-us.apache.org/repos/asf/mahout/blob/99a5358f/examples/src/main/java/org/apache/mahout/classifier/email/PrepEmailVectorsDriver.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/mahout/classifier/email/PrepEmailVectorsDriver.java b/examples/src/main/java/org/apache/mahout/classifier/email/PrepEmailVectorsDriver.java deleted file mode 100644 index 8fba739..0000000 --- a/examples/src/main/java/org/apache/mahout/classifier/email/PrepEmailVectorsDriver.java +++ /dev/null @@ -1,76 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.mahout.classifier.email; - -import org.apache.hadoop.conf.Configuration; -import org.apache.hadoop.fs.Path; -import org.apache.hadoop.io.Text; -import org.apache.hadoop.mapreduce.Job; -import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat; -import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat; -import org.apache.hadoop.util.ToolRunner; -import org.apache.mahout.common.AbstractJob; -import org.apache.mahout.common.HadoopUtil; -import org.apache.mahout.common.commandline.DefaultOptionCreator; -import org.apache.mahout.math.VectorWritable; - -import java.util.List; -import java.util.Map; - -/** - * Convert the labels generated by {@link org.apache.mahout.text.SequenceFilesFromMailArchives} and - * {@link org.apache.mahout.vectorizer.SparseVectorsFromSequenceFiles} to ones consumable by the classifiers. We do this - * here b/c if it is done in the creation of sparse vectors, the Reducer collapses all the vectors. - */ -public class PrepEmailVectorsDriver extends AbstractJob { - - public static final String ITEMS_PER_CLASS = "itemsPerClass"; - public static final String USE_LIST_NAME = "USE_LIST_NAME"; - - public static void main(String[] args) throws Exception { - ToolRunner.run(new Configuration(), new PrepEmailVectorsDriver(), args); - } - - @Override - public int run(String[] args) throws Exception { - addInputOption(); - addOutputOption(); - addOption(DefaultOptionCreator.overwriteOption().create()); - addOption("maxItemsPerLabel", "mipl", "The maximum number of items per label. Can be useful for making the " - + "training sets the same size", String.valueOf(100000)); - addOption(buildOption("useListName", "ul", "Use the name of the list as part of the label. 
If not set, then " - + "just use the project name", false, false, "false")); - Map<String,List<String>> parsedArgs = parseArguments(args); - if (parsedArgs == null) { - return -1; - } - - Path input = getInputPath(); - Path output = getOutputPath(); - if (hasOption(DefaultOptionCreator.OVERWRITE_OPTION)) { - HadoopUtil.delete(getConf(), output); - } - Job convertJob = prepareJob(input, output, SequenceFileInputFormat.class, PrepEmailMapper.class, Text.class, - VectorWritable.class, PrepEmailReducer.class, Text.class, VectorWritable.class, SequenceFileOutputFormat.class); - convertJob.getConfiguration().set(ITEMS_PER_CLASS, getOption("maxItemsPerLabel")); - convertJob.getConfiguration().set(USE_LIST_NAME, String.valueOf(hasOption("useListName"))); - - boolean succeeded = convertJob.waitForCompletion(true); - return succeeded ? 0 : -1; - } -} http://git-wip-us.apache.org/repos/asf/mahout/blob/99a5358f/examples/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/PosTagger.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/PosTagger.java b/examples/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/PosTagger.java deleted file mode 100644 index 9c0ef56..0000000 --- a/examples/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/PosTagger.java +++ /dev/null @@ -1,277 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.mahout.classifier.sequencelearning.hmm; - -import com.google.common.io.Resources; -import org.apache.commons.io.Charsets; -import org.apache.mahout.math.Matrix; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import java.io.IOException; -import java.net.URL; -import java.util.Arrays; -import java.util.HashMap; -import java.util.LinkedList; -import java.util.List; -import java.util.Map; -import java.util.regex.Pattern; - -/** - * This class implements a sample program that uses a pre-tagged training data - * set to train an HMM model as a POS tagger. The training data is automatically - * downloaded from the following URL: - * http://www.jaist.ac.jp/~hieuxuan/flexcrfs/CoNLL2000-NP/train.txt It then - * trains an HMM Model using supervised learning and tests the model on the - * following test data set: - * http://www.jaist.ac.jp/~hieuxuan/flexcrfs/CoNLL2000-NP/test.txt Further - * details regarding the data files can be found at - * http://flexcrfs.sourceforge.net/#Case_Study - */ -public final class PosTagger { - - private static final Logger log = LoggerFactory.getLogger(PosTagger.class); - - private static final Pattern SPACE = Pattern.compile(" "); - private static final Pattern SPACES = Pattern.compile("[ ]+"); - - /** - * No public constructors for utility classes. - */ - private PosTagger() { - // nothing to do here really. - } - - /** - * Model trained in the example. 
- */ - private static HmmModel taggingModel; - - /** - * Map for storing the IDs for the POS tags (hidden states) - */ - private static Map<String, Integer> tagIDs; - - /** - * Counter for the next assigned POS tag ID The value of 0 is reserved for - * "unknown POS tag" - */ - private static int nextTagId; - - /** - * Map for storing the IDs for observed words (observed states) - */ - private static Map<String, Integer> wordIDs; - - /** - * Counter for the next assigned word ID The value of 0 is reserved for - * "unknown word" - */ - private static int nextWordId = 1; // 0 is reserved for "unknown word" - - /** - * Used for storing a list of POS tags of read sentences. - */ - private static List<int[]> hiddenSequences; - - /** - * Used for storing a list of word tags of read sentences. - */ - private static List<int[]> observedSequences; - - /** - * number of read lines - */ - private static int readLines; - - /** - * Given an URL, this function fetches the data file, parses it, assigns POS - * Tag/word IDs and fills the hiddenSequences/observedSequences lists with - * data from those files. The data is expected to be in the following format - * (one word per line): word pos-tag np-tag sentences are closed with the . - * pos tag - * - * @param url Where the data file is stored - * @param assignIDs Should IDs for unknown words/tags be assigned? (Needed for - * training data, not needed for test data) - * @throws IOException in case data file cannot be read. 
- */ - private static void readFromURL(String url, boolean assignIDs) throws IOException { - // initialize the data structure - hiddenSequences = new LinkedList<>(); - observedSequences = new LinkedList<>(); - readLines = 0; - - // now read line by line of the input file - List<Integer> observedSequence = new LinkedList<>(); - List<Integer> hiddenSequence = new LinkedList<>(); - - for (String line :Resources.readLines(new URL(url), Charsets.UTF_8)) { - if (line.isEmpty()) { - // new sentence starts - int[] observedSequenceArray = new int[observedSequence.size()]; - int[] hiddenSequenceArray = new int[hiddenSequence.size()]; - for (int i = 0; i < observedSequence.size(); ++i) { - observedSequenceArray[i] = observedSequence.get(i); - hiddenSequenceArray[i] = hiddenSequence.get(i); - } - // now register those arrays - hiddenSequences.add(hiddenSequenceArray); - observedSequences.add(observedSequenceArray); - // and reset the linked lists - observedSequence.clear(); - hiddenSequence.clear(); - continue; - } - readLines++; - // we expect the format [word] [POS tag] [NP tag] - String[] tags = SPACE.split(line); - // when analyzing the training set, assign IDs - if (assignIDs) { - if (!wordIDs.containsKey(tags[0])) { - wordIDs.put(tags[0], nextWordId++); - } - if (!tagIDs.containsKey(tags[1])) { - tagIDs.put(tags[1], nextTagId++); - } - } - // determine the IDs - Integer wordID = wordIDs.get(tags[0]); - Integer tagID = tagIDs.get(tags[1]); - // now construct the current sequence - if (wordID == null) { - observedSequence.add(0); - } else { - observedSequence.add(wordID); - } - - if (tagID == null) { - hiddenSequence.add(0); - } else { - hiddenSequence.add(tagID); - } - } - - // if there is still something in the pipe, register it - if (!observedSequence.isEmpty()) { - int[] observedSequenceArray = new int[observedSequence.size()]; - int[] hiddenSequenceArray = new int[hiddenSequence.size()]; - for (int i = 0; i < observedSequence.size(); ++i) { - observedSequenceArray[i] 
= observedSequence.get(i); - hiddenSequenceArray[i] = hiddenSequence.get(i); - } - // now register those arrays - hiddenSequences.add(hiddenSequenceArray); - observedSequences.add(observedSequenceArray); - } - } - - private static void trainModel(String trainingURL) throws IOException { - tagIDs = new HashMap<>(44); // we expect 44 distinct tags - wordIDs = new HashMap<>(19122); // we expect 19122 - // distinct words - log.info("Reading and parsing training data file from URL: {}", trainingURL); - long start = System.currentTimeMillis(); - readFromURL(trainingURL, true); - long end = System.currentTimeMillis(); - double duration = (end - start) / 1000.0; - log.info("Parsing done in {} seconds!", duration); - log.info("Read {} lines containing {} sentences with a total of {} distinct words and {} distinct POS tags.", - readLines, hiddenSequences.size(), nextWordId - 1, nextTagId - 1); - start = System.currentTimeMillis(); - taggingModel = HmmTrainer.trainSupervisedSequence(nextTagId, nextWordId, - hiddenSequences, observedSequences, 0.05); - // we have to adjust the model a bit, - // since we assume a higher probability that a given unknown word is NNP - // than anything else - Matrix emissions = taggingModel.getEmissionMatrix(); - for (int i = 0; i < taggingModel.getNrOfHiddenStates(); ++i) { - emissions.setQuick(i, 0, 0.1 / taggingModel.getNrOfHiddenStates()); - } - int nnptag = tagIDs.get("NNP"); - emissions.setQuick(nnptag, 0, 1 / (double) taggingModel.getNrOfHiddenStates()); - // re-normalize the emission probabilities - HmmUtils.normalizeModel(taggingModel); - // now register the names - taggingModel.registerHiddenStateNames(tagIDs); - taggingModel.registerOutputStateNames(wordIDs); - end = System.currentTimeMillis(); - duration = (end - start) / 1000.0; - log.info("Trained HMM models in {} seconds!", duration); - } - - private static void testModel(String testingURL) throws IOException { - log.info("Reading and parsing test data file from URL: {}", 
testingURL); - long start = System.currentTimeMillis(); - readFromURL(testingURL, false); - long end = System.currentTimeMillis(); - double duration = (end - start) / 1000.0; - log.info("Parsing done in {} seconds!", duration); - log.info("Read {} lines containing {} sentences.", readLines, hiddenSequences.size()); - - start = System.currentTimeMillis(); - int errorCount = 0; - int totalCount = 0; - for (int i = 0; i < observedSequences.size(); ++i) { - // fetch the viterbi path as the POS tag for this observed sequence - int[] posEstimate = HmmEvaluator.decode(taggingModel, observedSequences.get(i), false); - // compare with the expected - int[] posExpected = hiddenSequences.get(i); - for (int j = 0; j < posExpected.length; ++j) { - totalCount++; - if (posEstimate[j] != posExpected[j]) { - errorCount++; - } - } - } - end = System.currentTimeMillis(); - duration = (end - start) / 1000.0; - log.info("POS tagged test file in {} seconds!", duration); - double errorRate = (double) errorCount / totalCount; - log.info("Tagged the test file with an error rate of: {}", errorRate); - } - - private static List<String> tagSentence(String sentence) { - // first, we need to isolate all punctuation characters, so that they - // can be recognized - sentence = sentence.replaceAll("[,.!?:;\"]", " $0 "); - sentence = sentence.replaceAll("''", " '' "); - // now we tokenize the sentence - String[] tokens = SPACES.split(sentence); - // now generate the observed sequence - int[] observedSequence = HmmUtils.encodeStateSequence(taggingModel, Arrays.asList(tokens), true, 0); - // POS tag this observedSequence - int[] hiddenSequence = HmmEvaluator.decode(taggingModel, observedSequence, false); - // and now decode the tag names - return HmmUtils.decodeStateSequence(taggingModel, hiddenSequence, false, null); - } - - public static void main(String[] args) throws IOException { - // generate the model from URL - trainModel("http://www.jaist.ac.jp/~hieuxuan/flexcrfs/CoNLL2000-NP/train.txt"); - 
testModel("http://www.jaist.ac.jp/~hieuxuan/flexcrfs/CoNLL2000-NP/test.txt"); - // tag an exemplary sentence - String test = "McDonalds is a huge company with many employees ."; - String[] testWords = SPACE.split(test); - List<String> posTags = tagSentence(test); - for (int i = 0; i < posTags.size(); ++i) { - log.info("{}[{}]", testWords[i], posTags.get(i)); - } - } - -}
