Author: srowen Date: Mon Mar 29 10:59:47 2010 New Revision: 928711 URL: http://svn.apache.org/viewvc?rev=928711&view=rev Log: Add standard deviation support to JDBC diff storage
Added: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/common/FixedRunningAverage.java - copied, changed from r928681, lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/common/FullRunningAverage.java lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/common/FixedRunningAverageAndStdDev.java Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/slopeone/jdbc/AbstractJDBCDiffStorage.java lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/slopeone/jdbc/MySQLJDBCDiffStorage.java Copied: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/common/FixedRunningAverage.java (from r928681, lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/common/FullRunningAverage.java) URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/common/FixedRunningAverage.java?p2=lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/common/FixedRunningAverage.java&p1=lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/common/FullRunningAverage.java&r1=928681&r2=928711&rev=928711&view=diff ============================================================================== --- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/common/FullRunningAverage.java (original) +++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/common/FixedRunningAverage.java Mon Mar 29 10:59:47 2010 @@ -21,80 +21,58 @@ import java.io.Serializable; /** * <p> - * A simple class that can keep track of a running avearage of a series of numbers. One can add to or remove - * from the series, as well as update a datum in the series. The class does not actually keep track of the - * series of values, just its running average, so it doesn't even matter if you remove/change a value that - * wasn't added. 
+ * A simple class that represents a fixed value of an average and count. This is useful + * when an API needs to return {@link RunningAverage} but is not in a position to accept + * updates to it. * </p> */ -public class FullRunningAverage implements RunningAverage, Serializable { - - private int count; - private double average; - - public FullRunningAverage() { - count = 0; - average = Double.NaN; +public class FixedRunningAverage implements RunningAverage, Serializable { + + private final double average; + private final int count; + + public FixedRunningAverage(double average, int count) { + this.average = average; + this.count = count; } - + /** - * @param datum - * new item to add to the running average + * @throws UnsupportedOperationException */ @Override public synchronized void addDatum(double datum) { - if (++count == 1) { - average = datum; - } else { - average = average * (count - 1) / count + datum / count; - } + throw new UnsupportedOperationException(); } - + /** - * @param datum - * item to remove to the running average - * @throws IllegalStateException - * if count is 0 + * @throws UnsupportedOperationException */ @Override public synchronized void removeDatum(double datum) { - if (count == 0) { - throw new IllegalStateException(); - } - if (--count == 0) { - average = Double.NaN; - } else { - average = average * (count + 1) / count - datum / count; - } + throw new UnsupportedOperationException(); } - + /** - * @param delta - * amount by which to change a datum in the running average - * @throws IllegalStateException - * if count is 0 + * @throws UnsupportedOperationException */ @Override public synchronized void changeDatum(double delta) { - if (count == 0) { - throw new IllegalStateException(); - } - average += delta / count; + throw new UnsupportedOperationException(); } - + @Override public synchronized int getCount() { return count; } - + @Override public synchronized double getAverage() { return average; } - + @Override public synchronized 
String toString() { return String.valueOf(average); } - -} + +} \ No newline at end of file Added: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/common/FixedRunningAverageAndStdDev.java URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/common/FixedRunningAverageAndStdDev.java?rev=928711&view=auto ============================================================================== --- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/common/FixedRunningAverageAndStdDev.java (added) +++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/common/FixedRunningAverageAndStdDev.java Mon Mar 29 10:59:47 2010 @@ -0,0 +1,46 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.cf.taste.impl.common; + +/** + * <p> + * A simple class that represents a fixed value of an average, count and standard deviation. This is useful + * when an API needs to return {@link RunningAverageAndStdDev} but is not in a position to accept + * updates to it. 
+ * </p> + */ +public final class FixedRunningAverageAndStdDev extends FixedRunningAverage implements RunningAverageAndStdDev { + + private final double stdDev; + + public FixedRunningAverageAndStdDev(double average, double stdDev, int count) { + super(average, count); + this.stdDev = stdDev; + } + + @Override + public synchronized String toString() { + return super.toString() + ',' + stdDev; + } + + @Override + public double getStandardDeviation() { + return stdDev; + } + +} \ No newline at end of file Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/slopeone/jdbc/AbstractJDBCDiffStorage.java URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/slopeone/jdbc/AbstractJDBCDiffStorage.java?rev=928711&r1=928710&r2=928711&view=diff ============================================================================== --- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/slopeone/jdbc/AbstractJDBCDiffStorage.java (original) +++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/slopeone/jdbc/AbstractJDBCDiffStorage.java Mon Mar 29 10:59:47 2010 @@ -30,6 +30,8 @@ import javax.sql.DataSource; import org.apache.mahout.cf.taste.common.Refreshable; import org.apache.mahout.cf.taste.common.TasteException; import org.apache.mahout.cf.taste.impl.common.FastIDSet; +import org.apache.mahout.cf.taste.impl.common.FixedRunningAverage; +import org.apache.mahout.cf.taste.impl.common.FixedRunningAverageAndStdDev; import org.apache.mahout.cf.taste.impl.common.RefreshHelper; import org.apache.mahout.cf.taste.impl.common.RunningAverage; import org.apache.mahout.cf.taste.impl.common.jdbc.AbstractJDBCComponent; @@ -45,7 +47,7 @@ import org.slf4j.LoggerFactory; * A {@link DiffStorage} which stores diffs in a database. Database-specific implementations subclass this * abstract class. 
Note that this implementation has a fairly particular dependence on the * {@link org.apache.mahout.cf.taste.model.DataModel} used; it needs a {@link JDBCDataModel} attached to the - * same database since its efficent operation depends on accessing preference data in the database directly. + * same database since its efficient operation depends on accessing preference data in the database directly. * </p> */ public abstract class AbstractJDBCDiffStorage extends AbstractJDBCComponent implements DiffStorage { @@ -57,7 +59,8 @@ public abstract class AbstractJDBCDiffSt public static final String DEFAULT_ITEM_B_COLUMN = "item_id_b"; public static final String DEFAULT_COUNT_COLUMN = "count"; public static final String DEFAULT_AVERAGE_DIFF_COLUMN = "average_diff"; - + public static final String DEFAULT_STDEV_COLUMN = "standard_deviation"; + private final DataSource dataSource; private final String getDiffSQL; private final String getDiffsSQL; @@ -140,7 +143,7 @@ public abstract class AbstractJDBCDiffSt stmt.setLong(4, itemID1); log.debug("Executing SQL query: {}", getDiffSQL); rs = stmt.executeQuery(); - return rs.next() ? new FixedRunningAverage(rs.getInt(1), rs.getDouble(2)) : null; + return rs.next() ? 
new FixedRunningAverageAndStdDev(rs.getDouble(2), rs.getDouble(3), rs.getInt(1)) : null; } catch (SQLException sqle) { log.warn("Exception while retrieving diff", sqle); throw new TasteException(sqle); @@ -175,7 +178,7 @@ public abstract class AbstractJDBCDiffSt i++; // result[i] is null for these values of i } - result[i] = new FixedRunningAverage(rs.getInt(1), rs.getDouble(2)); + result[i] = new FixedRunningAverageAndStdDev(rs.getDouble(2), rs.getDouble(3), rs.getInt(1)); i++; } } catch (SQLException sqle) { @@ -204,7 +207,7 @@ public abstract class AbstractJDBCDiffSt if (rs.next()) { int count = rs.getInt(1); if (count > 0) { - return new FixedRunningAverage(count, rs.getDouble(2)); + return new FixedRunningAverage(rs.getDouble(2), count); } } return null; @@ -215,7 +218,12 @@ public abstract class AbstractJDBCDiffSt IOUtils.quietClose(rs, stmt, conn); } } - + + /** + * Note that this implementation does <em>not</em> update standard deviations. This would + * be expensive relative to the value of slightly adjusting these values, which are merely + * used as weights. Rebuilding the diffs table will update standard deviations. 
+ */ @Override public void updateItemPref(long itemID, float prefDelta, boolean remove) throws TasteException { Connection conn = null; @@ -330,41 +338,4 @@ public abstract class AbstractJDBCDiffSt public void refresh(Collection<Refreshable> alreadyRefreshed) { refreshHelper.refresh(alreadyRefreshed); } - - private static class FixedRunningAverage implements RunningAverage { - - private final int count; - private final double average; - - private FixedRunningAverage(int count, double average) { - this.count = count; - this.average = average; - } - - @Override - public void addDatum(double datum) { - throw new UnsupportedOperationException(); - } - - @Override - public void removeDatum(double datum) { - throw new UnsupportedOperationException(); - } - - @Override - public void changeDatum(double delta) { - throw new UnsupportedOperationException(); - } - - @Override - public int getCount() { - return count; - } - - @Override - public double getAverage() { - return average; - } - } - } \ No newline at end of file Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/slopeone/jdbc/MySQLJDBCDiffStorage.java URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/slopeone/jdbc/MySQLJDBCDiffStorage.java?rev=928711&r1=928710&r2=928711&view=diff ============================================================================== --- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/slopeone/jdbc/MySQLJDBCDiffStorage.java (original) +++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/slopeone/jdbc/MySQLJDBCDiffStorage.java Mon Mar 29 10:59:47 2010 @@ -34,32 +34,36 @@ import org.apache.mahout.cf.taste.impl.m * <th>item_id_a</th> * <th>item_id_b</th> * <th>average_diff</th> + * <th>standard_deviation</th> * <th>count</th> * </tr> * <tr> * <td>123</td> * <td>234</td> * <td>0.5</td> + * <td>0.12</td> * <td>5</td> * 
</tr> * <tr> * <td>123</td> * <td>789</td> * <td>-1.33</td> + * <td>0.2</td> * <td>3</td> * </tr> * <tr> * <td>234</td> * <td>789</td> * <td>2.1</td> + * <td>1.03</td> * <td>1</td> * </tr> * </table> * * <p> * <code>item_id_a</code> and <code>item_id_b</code> should have types compatible with the long primitive - * type. <code>average_diff</code> must be compatible with <code>float</code> and <code>count</code> must be - * compatible with <code>int</code>. + * type. <code>average_diff</code> and <code>standard_deviation</code> must be compatible with + * <code>float</code> and <code>count</code> must be compatible with <code>int</code>. * </p> * * <p> @@ -73,6 +77,7 @@ import org.apache.mahout.cf.taste.impl.m * item_id_a BIGINT NOT NULL, * item_id_b BIGINT NOT NULL, * average_diff FLOAT NOT NULL, + * standard_deviation FLOAT NOT NULL, * count INT NOT NULL, * PRIMARY KEY (item_id_a, item_id_b), * INDEX (item_id_a), @@ -87,8 +92,14 @@ public final class MySQLJDBCDiffStorage private static final int DEFAULT_MIN_DIFF_COUNT = 2; public MySQLJDBCDiffStorage(AbstractJDBCDataModel dataModel) throws TasteException { - this(dataModel, DEFAULT_DIFF_TABLE, DEFAULT_ITEM_A_COLUMN, DEFAULT_ITEM_B_COLUMN, DEFAULT_COUNT_COLUMN, - DEFAULT_AVERAGE_DIFF_COLUMN, DEFAULT_MIN_DIFF_COUNT); + this(dataModel, + DEFAULT_DIFF_TABLE, + DEFAULT_ITEM_A_COLUMN, + DEFAULT_ITEM_B_COLUMN, + DEFAULT_COUNT_COLUMN, + DEFAULT_AVERAGE_DIFF_COLUMN, + DEFAULT_STDEV_COLUMN, + DEFAULT_MIN_DIFF_COUNT); } public MySQLJDBCDiffStorage(AbstractJDBCDataModel dataModel, @@ -97,14 +108,17 @@ public final class MySQLJDBCDiffStorage String itemIDBColumn, String countColumn, String avgColumn, + String stdevColumn, int minDiffCount) throws TasteException { super(dataModel, - // getDiffSQL - "SELECT " + countColumn + ", " + avgColumn + " FROM " + diffsTable + " WHERE " + itemIDAColumn + // getDiffSQL + "SELECT " + countColumn + ", " + avgColumn + ", " + stdevColumn + " FROM " + + diffsTable + " WHERE " + itemIDAColumn + 
"=? AND " + itemIDBColumn + "=? UNION " + "SELECT " + countColumn + ", " + avgColumn + " FROM " + diffsTable + " WHERE " + itemIDAColumn + "=? AND " + itemIDBColumn + "=?", // getDiffsSQL - "SELECT " + countColumn + ", " + avgColumn + ", " + itemIDAColumn + " FROM " + diffsTable + ", " + "SELECT " + countColumn + ", " + avgColumn + ", " + stdevColumn + ", " + itemIDAColumn + + " FROM " + diffsTable + ", " + dataModel.getPreferenceTable() + " WHERE " + itemIDBColumn + "=? AND " + itemIDAColumn + " = " + dataModel.getItemIDColumn() + " AND " + dataModel.getUserIDColumn() + "=? ORDER BY " + itemIDAColumn, @@ -139,17 +153,20 @@ public final class MySQLJDBCDiffStorage // deleteDiffsSQL "TRUNCATE " + diffsTable, // createDiffsSQL - "INSERT INTO " + diffsTable + " (" + itemIDAColumn + ", " + itemIDBColumn + ", " + avgColumn + ", " - + countColumn + ") SELECT prefsA." + dataModel.getItemIDColumn() + ", prefsB." - + dataModel.getItemIDColumn() + ',' + " AVG(prefsB." + dataModel.getPreferenceColumn() - + " - prefsA." + dataModel.getPreferenceColumn() + ")," + " COUNT(1) AS count FROM " - + dataModel.getPreferenceTable() + " prefsA, " + dataModel.getPreferenceTable() - + " prefsB WHERE prefsA." + dataModel.getUserIDColumn() + " = prefsB." - + dataModel.getUserIDColumn() + " AND prefsA." + dataModel.getItemIDColumn() + " < prefsB." - + dataModel.getItemIDColumn() + ' ' + " GROUP BY prefsA." + dataModel.getItemIDColumn() - + ", prefsB." + dataModel.getItemIDColumn() + " HAVING count >=?", + "INSERT INTO " + diffsTable + " (" + itemIDAColumn + ", " + itemIDBColumn + ", " + avgColumn + + ", " + stdevColumn + ", " + countColumn + ") SELECT prefsA." + dataModel.getItemIDColumn() + + ", prefsB." + dataModel.getItemIDColumn() + ", AVG(prefsB." + dataModel.getPreferenceColumn() + + " - prefsA." + dataModel.getPreferenceColumn() + "), STDDEV_POP(prefsB." + + dataModel.getPreferenceColumn() + " - prefsA." 
+ dataModel.getPreferenceColumn() + + "), COUNT(1) AS count FROM " + dataModel.getPreferenceTable() + " prefsA, " + + dataModel.getPreferenceTable() + " prefsB WHERE prefsA." + dataModel.getUserIDColumn() + + " = prefsB." + dataModel.getUserIDColumn() + " AND prefsA." + dataModel.getItemIDColumn() + + " < prefsB." + dataModel.getItemIDColumn() + ' ' + " GROUP BY prefsA." + + dataModel.getItemIDColumn() + ", prefsB." + dataModel.getItemIDColumn() + + " HAVING count >= ?", // diffsExistSQL - "SELECT COUNT(1) FROM " + diffsTable, minDiffCount); + "SELECT COUNT(1) FROM " + diffsTable, + minDiffCount); } /**