Author: ssc
Date: Thu Apr 7 21:10:28 2011
New Revision: 1090013
URL: http://svn.apache.org/viewvc?rev=1090013&view=rev
Log:
MAHOUT-657 Sample code to apply SVD to the KDD data
Added:
mahout/trunk/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/EstimateConverter.java
mahout/trunk/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/svd/
mahout/trunk/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/svd/DataModelFactorizablePreferences.java
mahout/trunk/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/svd/FactorizablePreferences.java
mahout/trunk/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/svd/KDDCupFactorizablePreferences.java
mahout/trunk/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/svd/ParallelArraysSGDFactorizer.java
mahout/trunk/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/svd/Track1SVDRunner.java
Modified:
mahout/trunk/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/Track1Callable.java
Added:
mahout/trunk/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/EstimateConverter.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/EstimateConverter.java?rev=1090013&view=auto
==============================================================================
---
mahout/trunk/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/EstimateConverter.java
(added)
+++
mahout/trunk/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/EstimateConverter.java
Thu Apr 7 21:10:28 2011
@@ -0,0 +1,43 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.example.kddcup.track1;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+public class EstimateConverter {
+
+ private static final Logger log =
LoggerFactory.getLogger(EstimateConverter.class);
+
+ private EstimateConverter() {}
+
+ public static byte convert(double estimate, long userID, long itemID) {
+ if (Double.isNaN(estimate)) {
+ log.warn("Unable to compute estimate for user {}, item {}", userID,
itemID);
+ return 0x7F;
+ } else {
+ int scaledEstimate = (int) (estimate * 2.55);
+ if (scaledEstimate > 255) {
+ scaledEstimate = 255;
+ } else if (scaledEstimate < 0) {
+ scaledEstimate = 0;
+ }
+ return (byte) scaledEstimate;
+ }
+ }
+}
Modified:
mahout/trunk/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/Track1Callable.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/Track1Callable.java?rev=1090013&r1=1090012&r2=1090013&view=diff
==============================================================================
---
mahout/trunk/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/Track1Callable.java
(original)
+++
mahout/trunk/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/Track1Callable.java
Thu Apr 7 21:10:28 2011
@@ -54,19 +54,7 @@ final class Track1Callable implements Ca
log.warn("Unknown item {}; OK unless this is the real contest data",
itemID);
continue;
}
-
- if (Double.isNaN(estimate)) {
- log.warn("Unable to compute estimate for user {}, item {}", userID,
itemID);
- result[i] = 0x7F;
- } else {
- int scaledEstimate = (int) (estimate * 2.55);
- if (scaledEstimate > 255) {
- scaledEstimate = 255;
- } else if (scaledEstimate < 0) {
- scaledEstimate = 0;
- }
- result[i] = (byte) scaledEstimate;
- }
+ result[i] = EstimateConverter.convert(estimate, userID, itemID);
}
if (COUNT.incrementAndGet() % 10000 == 0) {
Added:
mahout/trunk/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/svd/DataModelFactorizablePreferences.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/svd/DataModelFactorizablePreferences.java?rev=1090013&view=auto
==============================================================================
---
mahout/trunk/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/svd/DataModelFactorizablePreferences.java
(added)
+++
mahout/trunk/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/svd/DataModelFactorizablePreferences.java
Thu Apr 7 21:10:28 2011
@@ -0,0 +1,106 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.example.kddcup.track1.svd;
+
+import org.apache.mahout.cf.taste.impl.common.FastIDSet;
+import org.apache.mahout.cf.taste.impl.common.LongPrimitiveIterator;
+import org.apache.mahout.cf.taste.impl.model.GenericPreference;
+import org.apache.mahout.cf.taste.model.DataModel;
+import org.apache.mahout.cf.taste.model.Preference;
+
+import java.util.ArrayList;
+import java.util.List;
+
+/**
+ * can be used to drop {@link DataModel}s into {@link
ParallelArraysSGDFactorizer}
+ */
+public class DataModelFactorizablePreferences implements
FactorizablePreferences {
+
+ private final FastIDSet userIDs;
+ private final FastIDSet itemIDs;
+
+ private final List<Preference> preferences;
+
+ private final float minPreference;
+ private final float maxPreference;
+
+ public DataModelFactorizablePreferences(DataModel dataModel) {
+
+ minPreference = dataModel.getMinPreference();
+ maxPreference = dataModel.getMaxPreference();
+
+ try {
+ userIDs = new FastIDSet(dataModel.getNumUsers());
+ itemIDs = new FastIDSet(dataModel.getNumItems());
+ preferences = new ArrayList<Preference>();
+
+ LongPrimitiveIterator userIDsIterator = dataModel.getUserIDs();
+ while (userIDsIterator.hasNext()) {
+ long userID = userIDsIterator.nextLong();
+ userIDs.add(userID);
+ for (Preference preference : dataModel.getPreferencesFromUser(userID))
{
+ itemIDs.add(preference.getItemID());
+ preferences.add(new GenericPreference(userID,
preference.getItemID(), preference.getValue()));
+ }
+ }
+ } catch (Exception e) {
+ throw new IllegalStateException("Unable to create factorizable
preferences!", e);
+ }
+ }
+
+ @Override
+ public LongPrimitiveIterator getUserIDs() {
+ return userIDs.iterator();
+ }
+
+ @Override
+ public LongPrimitiveIterator getItemIDs() {
+ return itemIDs.iterator();
+ }
+
+ @Override
+ public Iterable<Preference> getPreferences() {
+ return preferences;
+ }
+
+ @Override
+ public float getMinPreference() {
+ return minPreference;
+ }
+
+ @Override
+ public float getMaxPreference() {
+ return maxPreference;
+ }
+
+ @Override
+ public int numUsers() {
+ return userIDs.size();
+ }
+
+ @Override
+ public int numItems() {
+ return itemIDs.size();
+ }
+
+ @Override
+ public int numPreferences() {
+ return preferences.size();
+ }
+}
+
Added:
mahout/trunk/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/svd/FactorizablePreferences.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/svd/FactorizablePreferences.java?rev=1090013&view=auto
==============================================================================
---
mahout/trunk/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/svd/FactorizablePreferences.java
(added)
+++
mahout/trunk/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/svd/FactorizablePreferences.java
Thu Apr 7 21:10:28 2011
@@ -0,0 +1,44 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.example.kddcup.track1.svd;
+
+import org.apache.mahout.cf.taste.impl.common.LongPrimitiveIterator;
+import org.apache.mahout.cf.taste.model.Preference;
+
+/**
+ * models the necessary input for {@link ParallelArraysSGDFactorizer}
+ */
+public interface FactorizablePreferences {
+
+ LongPrimitiveIterator getUserIDs();
+
+ LongPrimitiveIterator getItemIDs();
+
+ Iterable<Preference> getPreferences();
+
+ float getMinPreference();
+
+ float getMaxPreference();
+
+ int numUsers();
+
+ int numItems();
+
+ int numPreferences();
+
+}
Added:
mahout/trunk/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/svd/KDDCupFactorizablePreferences.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/svd/KDDCupFactorizablePreferences.java?rev=1090013&view=auto
==============================================================================
---
mahout/trunk/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/svd/KDDCupFactorizablePreferences.java
(added)
+++
mahout/trunk/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/svd/KDDCupFactorizablePreferences.java
Thu Apr 7 21:10:28 2011
@@ -0,0 +1,159 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.example.kddcup.track1.svd;
+
import org.apache.mahout.cf.taste.example.kddcup.DataFileIterator;
import org.apache.mahout.cf.taste.impl.common.LongPrimitiveIterator;
import org.apache.mahout.cf.taste.model.Preference;

import java.io.File;
import java.io.IOException;
import java.util.Iterator;
import java.util.NoSuchElementException;
+
+public class KDDCupFactorizablePreferences implements FactorizablePreferences {
+
+ private final File dataFile;
+
+ public KDDCupFactorizablePreferences(File dataFile) {
+ this.dataFile = dataFile;
+ }
+
+ @Override
+ public LongPrimitiveIterator getUserIDs() {
+ return new FixedSizeLongIterator(numUsers());
+ }
+
+ @Override
+ public LongPrimitiveIterator getItemIDs() {
+ return new FixedSizeLongIterator(numItems());
+ }
+
+ @Override
+ public Iterable<Preference> getPreferences() {
+ return new Iterable<Preference>() {
+ @Override
+ public Iterator<Preference> iterator() {
+ try {
+ return new DataFilePreferencesIterator(new
DataFileIterator(dataFile));
+ } catch (IOException e) {
+ throw new IllegalStateException("Cannot iterate over datafile!", e);
+ }
+ }
+ };
+ }
+
+ @Override
+ public float getMinPreference() {
+ return 0;
+ }
+
+ @Override
+ public float getMaxPreference() {
+ return 100;
+ }
+
+ @Override
+ public int numUsers() {
+ return 1000990;
+ }
+
+ @Override
+ public int numItems() {
+ return 624961;
+ }
+
+ @Override
+ public int numPreferences() {
+ return 252800275;
+ }
+
+ static class DataFilePreferencesIterator implements Iterator<Preference> {
+
+ private final DataFileIterator dataFileIterator;
+
+ Iterator<Preference> currentUserPrefsIterator;
+
+ public DataFilePreferencesIterator(DataFileIterator dataFileIterator) {
+ this.dataFileIterator = dataFileIterator;
+ }
+
+ @Override
+ public boolean hasNext() {
+ if (currentUserPrefsIterator != null &&
currentUserPrefsIterator.hasNext()) {
+ return true;
+ } else {
+ return dataFileIterator.hasNext();
+ }
+ }
+
+ @Override
+ public Preference next() {
+ if (currentUserPrefsIterator == null ||
!currentUserPrefsIterator.hasNext()) {
+ currentUserPrefsIterator =
dataFileIterator.next().getFirst().iterator();
+ }
+ return currentUserPrefsIterator.next();
+ }
+
+ @Override
+ public void remove() {
+ throw new UnsupportedOperationException();
+ }
+ }
+
+ static class FixedSizeLongIterator implements LongPrimitiveIterator {
+
+ private long currentValue;
+ private final long maximum;
+
+ public FixedSizeLongIterator(long maximum) {
+ this.maximum = maximum;
+ currentValue = 0;
+ }
+
+ @Override
+ public long nextLong() {
+ return currentValue++;
+ }
+
+ @Override
+ public long peek() {
+ return currentValue;
+ }
+
+ @Override
+ public void skip(int n) {
+ currentValue += n;
+ }
+
+ @Override
+ public boolean hasNext() {
+ return currentValue < maximum;
+ }
+
+ @Override
+ public Long next() {
+ return ++currentValue;
+ }
+
+ @Override
+ public void remove() {
+ throw new UnsupportedOperationException();
+ }
+ }
+
+}
Added:
mahout/trunk/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/svd/ParallelArraysSGDFactorizer.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/svd/ParallelArraysSGDFactorizer.java?rev=1090013&view=auto
==============================================================================
---
mahout/trunk/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/svd/ParallelArraysSGDFactorizer.java
(added)
+++
mahout/trunk/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/svd/ParallelArraysSGDFactorizer.java
Thu Apr 7 21:10:28 2011
@@ -0,0 +1,257 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.example.kddcup.track1.svd;
+
+import org.apache.mahout.cf.taste.common.TasteException;
+import org.apache.mahout.cf.taste.impl.common.FastByIDMap;
+import org.apache.mahout.cf.taste.impl.common.FullRunningAverage;
+import org.apache.mahout.cf.taste.impl.common.LongPrimitiveIterator;
+import org.apache.mahout.cf.taste.impl.common.RunningAverage;
+import org.apache.mahout.cf.taste.impl.recommender.svd.Factorization;
+import org.apache.mahout.cf.taste.impl.recommender.svd.Factorizer;
+import org.apache.mahout.cf.taste.model.DataModel;
+import org.apache.mahout.cf.taste.model.Preference;
+import org.apache.mahout.common.RandomUtils;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.util.Random;
+
+/**
+ * {@link org.apache.mahout.cf.taste.impl.recommender.svd.Factorizer} based on
Simon Funk's famous article "Netflix Update: Try this at home"
+ * {@see http://sifter.org/~simon/journal/20061211.html}.
+ *
+ * Attempts to be as memory efficient as possible, only iterating once through
the {@link FactorizablePreferences} or {@link DataModel} while
+ * copying everything to primitive arrays. Learning works in place on these
datastructures after that.
+ *
+ */
+public class ParallelArraysSGDFactorizer implements Factorizer {
+
+ public static final double DEFAULT_LEARNING_RATE = 0.005;
+ public static final double DEFAULT_PREVENT_OVERFITTING = 0.02;
+ public static final double DEFAULT_RANDOM_NOISE = 0.005;
+
+ private final int numFeatures;
+ private final int numIterations;
+ private final float minPreference;
+ private final float maxPreference;
+
+ private final Random random;
+ private final double learningRate;
+ private final double preventOverfitting;
+
+ private final FastByIDMap<Integer> userIDMapping;
+ private final FastByIDMap<Integer> itemIDMapping;
+
+ private final double[][] userFeatures;
+ private final double[][] itemFeatures;
+
+ private final int[] userIndexes;
+ private final int[] itemIndexes;
+ private final float[] values;
+
+ private final double defaultValue;
+ private final double interval;
+ private final double[] cachedEstimates;
+
+
+ private static final Logger log =
LoggerFactory.getLogger(ParallelArraysSGDFactorizer.class);
+
+ public ParallelArraysSGDFactorizer(DataModel dataModel, int numFeatures, int
numIterations) {
+ this(new DataModelFactorizablePreferences(dataModel), numFeatures,
numIterations, DEFAULT_LEARNING_RATE,
+ DEFAULT_PREVENT_OVERFITTING, DEFAULT_RANDOM_NOISE);
+ }
+
+ public ParallelArraysSGDFactorizer(DataModel dataModel, int numFeatures, int
numIterations, double learningRate,
+ double preventOverfitting, double
randomNoise) {
+ this(new DataModelFactorizablePreferences(dataModel), numFeatures,
numIterations, learningRate, preventOverfitting,
+ randomNoise);
+ }
+
+ public ParallelArraysSGDFactorizer(FactorizablePreferences
factorizablePrefs, int numFeatures, int numIterations) {
+ this(factorizablePrefs, numFeatures, numIterations, DEFAULT_LEARNING_RATE,
DEFAULT_PREVENT_OVERFITTING,
+ DEFAULT_RANDOM_NOISE);
+ }
+
+ public ParallelArraysSGDFactorizer(FactorizablePreferences
factorizablePreferences, int numFeatures,
+ int numIterations, double learningRate, double preventOverfitting,
double randomNoise) {
+
+ this.numFeatures = numFeatures;
+ this.numIterations = numIterations;
+ minPreference = factorizablePreferences.getMinPreference();
+ maxPreference = factorizablePreferences.getMaxPreference();
+
+ this.random = RandomUtils.getRandom();
+ this.learningRate = learningRate;
+ this.preventOverfitting = preventOverfitting;
+
+ int numUsers = factorizablePreferences.numUsers();
+ int numItems = factorizablePreferences.numItems();
+ int numPrefs = factorizablePreferences.numPreferences();
+
+ log.info("Mapping {} users...", numUsers);
+ userIDMapping = new FastByIDMap<Integer>(numUsers);
+ int index = 0;
+ LongPrimitiveIterator userIterator = factorizablePreferences.getUserIDs();
+ while (userIterator.hasNext()) {
+ userIDMapping.put(userIterator.nextLong(), index++);
+ }
+
+ log.info("Mapping {} items", numItems);
+ itemIDMapping = new FastByIDMap<Integer>(numItems);
+ index = 0;
+ LongPrimitiveIterator itemIterator = factorizablePreferences.getItemIDs();
+ while (itemIterator.hasNext()) {
+ itemIDMapping.put(itemIterator.nextLong(), index++);
+ }
+
+ this.userIndexes = new int[numPrefs];
+ this.itemIndexes = new int[numPrefs];
+ this.values = new float[numPrefs];
+ this.cachedEstimates = new double[numPrefs];
+
+ index = 0;
+ log.info("Loading {} preferences into memory", numPrefs);
+ RunningAverage average = new FullRunningAverage();
+ for (Preference preference : factorizablePreferences.getPreferences()) {
+ userIndexes[index] = userIDMapping.get(preference.getUserID());
+ itemIndexes[index] = itemIDMapping.get(preference.getItemID());
+ values[index] = preference.getValue();
+ cachedEstimates[index] = 0;
+
+ average.addDatum(preference.getValue());
+
+ index++;
+ if (index % 1000000 == 0) {
+ log.info("Processed {} preferences", index);
+ }
+ }
+ log.info("Processed {} preferences, done.", index);
+
+ double averagePreference = average.getAverage();
+ log.info("Average preference value is {}", averagePreference);
+
+ double prefInterval = factorizablePreferences.getMaxPreference() -
factorizablePreferences.getMinPreference();
+ defaultValue = Math.sqrt((averagePreference - (prefInterval * 0.1)) /
numFeatures);
+ interval = (prefInterval * 0.1) / numFeatures;
+
+ userFeatures = new double[numUsers][numFeatures];
+ itemFeatures = new double[numItems][numFeatures];
+
+ log.info("Initializing feature vectors...");
+ for (int feature = 0; feature < numFeatures; feature++) {
+ for (int userIndex = 0; userIndex < numUsers; userIndex++) {
+ userFeatures[userIndex][feature] = defaultValue + (random.nextDouble()
- 0.5) * interval * randomNoise;
+ }
+ for (int itemIndex = 0; itemIndex < numItems; itemIndex++) {
+ itemFeatures[itemIndex][feature] = defaultValue + (random.nextDouble()
- 0.5) * interval * randomNoise;
+ }
+ }
+ }
+
+ @Override
+ public Factorization factorize() throws TasteException {
+ for (int feature = 0; feature < numFeatures; feature++) {
+ log.info("Shuffling preferences...");
+ shufflePreferences();
+ log.info("Starting training of feature {} ...", feature);
+ for (int currentIteration = 0; currentIteration < numIterations;
currentIteration++) {
+ if (currentIteration != (numIterations - 1)) {
+ trainingIteration(feature);
+ } else {
+ double rmse = trainingIterationWithRmse(feature);
+ log.info("Finished training feature {} with RMSE {}", feature, rmse);
+ }
+ }
+ if (feature < numFeatures - 1) {
+ log.info("Updating cache...");
+ for (int index = 0; index < userIndexes.length; index++) {
+ cachedEstimates[index] = estimate(userIndexes[index],
itemIndexes[index], feature, cachedEstimates[index],
+ false);
+ }
+ }
+ }
+ log.info("Factorization done");
+ return new Factorization(userIDMapping, itemIDMapping, userFeatures,
itemFeatures);
+ }
+
+ private void trainingIteration(int feature) {
+ for (int index = 0; index < userIndexes.length; index++) {
+ train(userIndexes[index], itemIndexes[index], feature, values[index],
cachedEstimates[index]);
+ }
+ }
+
+ private double trainingIterationWithRmse(int feature) {
+ double rmse = 0;
+ for (int index = 0; index < userIndexes.length; index++) {
+ double error = train(userIndexes[index], itemIndexes[index], feature,
values[index], cachedEstimates[index]);
+ rmse += (error * error);
+ }
+ return Math.sqrt(rmse / (double) userIndexes.length);
+ }
+
+ private double estimate(int userIndex, int itemIndex, int feature, double
cachedEstimate, boolean trailing) {
+ double sum = cachedEstimate;
+ sum += userFeatures[userIndex][feature] * itemFeatures[itemIndex][feature];
+ if (trailing) {
+ sum += (numFeatures - feature - 1) * ((defaultValue + interval) *
(defaultValue + interval));
+ if (sum > maxPreference) {
+ sum = maxPreference;
+ } else if (sum < minPreference) {
+ sum = minPreference;
+ }
+ }
+ return sum;
+ }
+
+ public double train(int userIndex, int itemIndex, int feature, double
original, double cachedEstimate) {
+ double error = original - estimate(userIndex, itemIndex, feature,
cachedEstimate, true);
+ double[] userVector = userFeatures[userIndex];
+ double[] itemVector = itemFeatures[itemIndex];
+
+ userVector[feature] += learningRate * (error * itemVector[feature] -
preventOverfitting * userVector[feature]);
+ itemVector[feature] += learningRate * (error * userVector[feature] -
preventOverfitting * itemVector[feature]);
+
+ return error;
+ }
+
+ protected void shufflePreferences() {
+ /* Durstenfeld shuffle */
+ for (int currentPos = userIndexes.length - 1; currentPos > 0;
currentPos--) {
+ int swapPos = random.nextInt(currentPos + 1);
+ swapPreferences(currentPos, swapPos);
+ }
+ }
+
+ private void swapPreferences(int posA, int posB) {
+ int tmpUserIndex = userIndexes[posA];
+ int tmpItemIndex = itemIndexes[posA];
+ float tmpValue = values[posA];
+ double tmpEstimate = cachedEstimates[posA];
+
+ userIndexes[posA] = userIndexes[posB];
+ itemIndexes[posA] = itemIndexes[posB];
+ values[posA] = values[posB];
+ cachedEstimates[posA] = cachedEstimates[posB];
+
+ userIndexes[posB] = tmpUserIndex;
+ itemIndexes[posB] = tmpItemIndex;
+ values[posB] = tmpValue;
+ cachedEstimates[posB] = tmpEstimate;
+ }
+}
Added:
mahout/trunk/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/svd/Track1SVDRunner.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/svd/Track1SVDRunner.java?rev=1090013&view=auto
==============================================================================
---
mahout/trunk/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/svd/Track1SVDRunner.java
(added)
+++
mahout/trunk/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/svd/Track1SVDRunner.java
Thu Apr 7 21:10:28 2011
@@ -0,0 +1,121 @@
+package org.apache.mahout.cf.taste.example.kddcup.track1.svd;
+
+import org.apache.mahout.cf.taste.common.NoSuchItemException;
+import org.apache.mahout.cf.taste.common.NoSuchUserException;
+import org.apache.mahout.cf.taste.example.kddcup.DataFileIterable;
+import org.apache.mahout.cf.taste.example.kddcup.KDDCupDataModel;
+import org.apache.mahout.cf.taste.example.kddcup.track1.EstimateConverter;
+import org.apache.mahout.cf.taste.impl.common.FullRunningAverage;
+import org.apache.mahout.cf.taste.impl.common.RunningAverage;
+import org.apache.mahout.cf.taste.impl.recommender.svd.Factorization;
+import org.apache.mahout.cf.taste.impl.recommender.svd.Factorizer;
+import org.apache.mahout.cf.taste.model.Preference;
+import org.apache.mahout.cf.taste.model.PreferenceArray;
+import org.apache.mahout.common.Pair;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.BufferedOutputStream;
+import java.io.File;
+import java.io.FileOutputStream;
+import java.io.OutputStream;
+
+/**
+ * run an SVD factorization of the KDD track1 data.
+ *
+ * needs at least 6-7GB of memory, tested with -Xms6700M -Xmx6700M
+ *
+ */
+public class Track1SVDRunner {
+
+ private static final Logger log =
LoggerFactory.getLogger(Track1SVDRunner.class);
+
+ public static void main(String[] args) throws Exception {
+
+ if (args.length != 2) {
+ System.err.println("Necessary arguments: <kddDataFileDirectory>
<resultFile>");
+ System.exit(-1);
+ }
+
+ File dataFileDirectory = new File(args[0]);
+ if (!dataFileDirectory.exists() || !dataFileDirectory.isDirectory()) {
+ throw new IllegalArgumentException("Bad data file directory: " +
dataFileDirectory);
+ }
+
+ File resultFile = new File(args[1]);
+
+ /* the knobs to turn */
+ int numFeatures = 20;
+ int numIterations = 5;
+ double learningRate = 0.0001;
+ double preventOverfitting = 0.002;
+ double randomNoise = 0.0001;
+
+
+ KDDCupFactorizablePreferences factorizablePreferences =
+ new
KDDCupFactorizablePreferences(KDDCupDataModel.getTrainingFile(dataFileDirectory));
+
+ Factorizer sgdFactorizer = new
ParallelArraysSGDFactorizer(factorizablePreferences, numFeatures, numIterations,
+ learningRate, preventOverfitting, randomNoise);
+
+ Factorization factorization = sgdFactorizer.factorize();
+
+ log.info("Estimating validation preferences...");
+ int prefsProcessed = 0;
+ RunningAverage average = new FullRunningAverage();
+ DataFileIterable validations = new
DataFileIterable(KDDCupDataModel.getValidationFile(dataFileDirectory));
+ for (Pair<PreferenceArray,long[]> validationPair : validations) {
+ for (Preference validationPref : validationPair.getFirst()) {
+ double estimate = estimatePreference(factorization,
validationPref.getUserID(), validationPref.getItemID(),
+ factorizablePreferences.getMinPreference(),
factorizablePreferences.getMaxPreference());
+ double error = validationPref.getValue() - estimate;
+ average.addDatum(error * error);
+ prefsProcessed++;
+ if (prefsProcessed % 100000 == 0) {
+ log.info("Computed {} estimations", prefsProcessed);
+ }
+ }
+ }
+ log.info("Computed {} estimations, done.", prefsProcessed);
+
+ double rmse = Math.sqrt(average.getAverage());
+ log.info("RMSE {}", rmse);
+
+ log.info("Estimating test preferences...");
+ OutputStream out = null;
+ try {
+ out = new BufferedOutputStream(new FileOutputStream(resultFile));
+
+ DataFileIterable tests = new
DataFileIterable(KDDCupDataModel.getTestFile(dataFileDirectory));
+ for (Pair<PreferenceArray,long[]> testPair : tests) {
+ for (Preference testPref : testPair.getFirst()) {
+ double estimate = estimatePreference(factorization,
testPref.getUserID(), testPref.getItemID(),
+ factorizablePreferences.getMinPreference(),
factorizablePreferences.getMaxPreference());
+ byte result = EstimateConverter.convert(estimate,
testPref.getUserID(), testPref.getItemID());
+ out.write(result);
+ }
+ }
+ } finally {
+ out.flush();
+ out.close();
+ }
+ log.info("wrote estimates to {}, done.", resultFile.getAbsolutePath());
+ }
+
+ static double estimatePreference(Factorization factorization, long userID,
long itemID, float minPreference,
+ float maxPreference) throws NoSuchUserException, NoSuchItemException {
+ double[] userFeatures = factorization.getUserFeatures(userID);
+ double[] itemFeatures = factorization.getItemFeatures(itemID);
+ double estimate = 0;
+ for (int feature = 0; feature < userFeatures.length; feature++) {
+ estimate += userFeatures[feature] * itemFeatures[feature];
+ }
+ if (estimate < minPreference) {
+ estimate = minPreference;
+ } else if (estimate > maxPreference) {
+ estimate = maxPreference;
+ }
+ return estimate;
+ }
+
+}