Author: srowen
Date: Sat Jan 15 16:28:49 2011
New Revision: 1059365

URL: http://svn.apache.org/viewvc?rev=1059365&view=rev
Log:
MAHOUT-576

Modified:
    
mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/slopeone/MemoryDiffStorage.java
    
mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/slopeone/SlopeOneRecommender.java
    
mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/slopeone/file/FileDiffStorage.java
    
mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/slopeone/jdbc/AbstractJDBCDiffStorage.java
    
mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/slopeone/jdbc/MySQLJDBCDiffStorage.java
    
mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/recommender/slopeone/DiffStorage.java
    
mahout/trunk/core/src/test/java/org/apache/mahout/cf/taste/impl/recommender/slopeone/MemoryDiffStorageTest.java

Modified: 
mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/slopeone/MemoryDiffStorage.java
URL: 
http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/slopeone/MemoryDiffStorage.java?rev=1059365&r1=1059364&r2=1059365&view=diff
==============================================================================
--- 
mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/slopeone/MemoryDiffStorage.java
 (original)
+++ 
mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/slopeone/MemoryDiffStorage.java
 Sat Jan 15 16:28:49 2011
@@ -69,10 +69,6 @@ public final class MemoryDiffStorage imp
   
   /**
    * <p>
-   * Creates a new .
-   * </p>
-   * 
-   * <p>
    * See {@link 
org.apache.mahout.cf.taste.impl.recommender.slopeone.SlopeOneRecommender} for 
the meaning of
    * <code>stdDevWeighted</code>. If <code>compactAverages</code> is set, this 
uses alternate data structures
    * ({@link CompactRunningAverage} versus {@link FullRunningAverage}) that 
use almost 50% less memory but
@@ -175,10 +171,53 @@ public final class MemoryDiffStorage imp
   public RunningAverage getAverageItemPref(long itemID) {
     return averageItemPref.get(itemID);
   }
+
+  @Override
+  public void addItemPref(long userID, long itemIDA, float prefValue) throws 
TasteException {
+    PreferenceArray userPreferences = dataModel.getPreferencesFromUser(userID);
+    try {
+      buildAverageDiffsLock.writeLock().lock();
+
+      FastByIDMap<RunningAverage> aMap = averageDiffs.get(itemIDA);
+      if (aMap == null) {
+        aMap = new FastByIDMap<RunningAverage>();
+        averageDiffs.put(itemIDA, aMap);
+      }
+
+      int length = userPreferences.length();
+      for (int i = 0; i < length; i++) {
+        long itemIDB = userPreferences.getItemID(i);
+        float bValue = userPreferences.getValue(i);
+        if (itemIDA < itemIDB) {
+          RunningAverage average = aMap.get(itemIDB);
+          if (average == null) {
+            average = buildRunningAverage();
+            aMap.put(itemIDB, average);
+          }
+          average.addDatum(bValue - prefValue);
+        } else {
+          FastByIDMap<RunningAverage> bMap = averageDiffs.get(itemIDB);
+          if (bMap == null) {
+            bMap = new FastByIDMap<RunningAverage>();
+            averageDiffs.put(itemIDB, bMap);
+          }
+          RunningAverage average = bMap.get(itemIDA);
+          if (average == null) {
+            average = buildRunningAverage();
+            bMap.put(itemIDA, average);
+          }
+          average.addDatum(prefValue - bValue);
+        }
+      }
+
+    } finally {
+      buildAverageDiffsLock.writeLock().unlock();
+    }
+  }
   
   @Override
-  public void updateItemPref(long itemID, float prefDelta, boolean remove) {
-    if (!remove && stdDevWeighted) {
+  public void updateItemPref(long itemID, float prefDelta) {
+    if (stdDevWeighted) {
       throw new UnsupportedOperationException("Can't update only when 
stdDevWeighted is set");
     }
     try {
@@ -188,17 +227,9 @@ public final class MemoryDiffStorage imp
         for (Map.Entry<Long,RunningAverage> entry2 : 
entry.getValue().entrySet()) {
           RunningAverage average = entry2.getValue();
           if (matchesItemID1) {
-            if (remove) {
-              average.removeDatum(prefDelta);
-            } else {
-              average.changeDatum(-prefDelta);
-            }
+            average.changeDatum(-prefDelta);
           } else if (itemID == entry2.getKey()) {
-            if (remove) {
-              average.removeDatum(-prefDelta);
-            } else {
-              average.changeDatum(prefDelta);
-            }
+            average.changeDatum(prefDelta);
           }
         }
       }
@@ -210,6 +241,55 @@ public final class MemoryDiffStorage imp
       buildAverageDiffsLock.readLock().unlock();
     }
   }
+
+  @Override
+  public void removeItemPref(long userID, long itemIDA, float prefValue) 
throws TasteException {
+    PreferenceArray userPreferences = dataModel.getPreferencesFromUser(userID);
+    try {
+      buildAverageDiffsLock.writeLock().lock();
+
+      FastByIDMap<RunningAverage> aMap = averageDiffs.get(itemIDA);
+      // Diffs are looked up under the smaller item ID first: averageDiffs.get(lo).get(hi)
+      int length = userPreferences.length();
+      for (int i = 0; i < length; i++) {
+
+        long itemIDB = userPreferences.getItemID(i);
+        float bValue = userPreferences.getValue(i);
+
+        if (itemIDA < itemIDB) {
+
+          if (aMap != null) {
+            RunningAverage average = aMap.get(itemIDB);
+            if (average != null) {
+              if (average.getCount() <= 1) {
+                aMap.remove(itemIDB);
+              } else {
+                average.removeDatum(bValue - prefValue);
+              }
+            }
+          }
+
+        } else  if (itemIDA > itemIDB) {
+          // itemIDB is the smaller ID here, so the diff lives in bMap (aMap may be null)
+          FastByIDMap<RunningAverage> bMap = averageDiffs.get(itemIDB);
+          if (bMap != null) {
+            RunningAverage average = bMap.get(itemIDA);
+            if (average != null) {
+              if (average.getCount() <= 1) {
+                bMap.remove(itemIDA);
+              } else {
+                average.removeDatum(prefValue - bValue);
+              }
+            }
+          }
+
+        }
+      }
+
+    } finally {
+      buildAverageDiffsLock.writeLock().unlock();
+    }
+  }
   
   @Override
   public FastIDSet getRecommendableItemIDs(long userID) throws TasteException {

Modified: 
mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/slopeone/SlopeOneRecommender.java
URL: 
http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/slopeone/SlopeOneRecommender.java?rev=1059365&r1=1059364&r2=1059365&view=diff
==============================================================================
--- 
mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/slopeone/SlopeOneRecommender.java
 (original)
+++ 
mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/slopeone/SlopeOneRecommender.java
 Sat Jan 15 16:28:49 2011
@@ -172,15 +172,20 @@ public final class SlopeOneRecommender e
   @Override
   public void setPreference(long userID, long itemID, float value) throws 
TasteException {
     DataModel dataModel = getDataModel();
-    float prefDelta;
+    Float oldPref;
     try {
-      Float oldPref = dataModel.getPreferenceValue(userID, itemID);
-      prefDelta = oldPref == null ? value : value - oldPref;
+      oldPref = dataModel.getPreferenceValue(userID, itemID);
     } catch (NoSuchUserException nsee) {
-      prefDelta = value;
+      oldPref = null;
     }
     super.setPreference(userID, itemID, value);
-    diffStorage.updateItemPref(itemID, prefDelta, false);
+    if (oldPref == null) {
+      // Add new preference
+      diffStorage.addItemPref(userID, itemID, value);
+    } else {
+      // Update preference
+      diffStorage.updateItemPref(itemID, value - oldPref);
+    }
   }
   
   @Override
@@ -189,7 +194,7 @@ public final class SlopeOneRecommender e
     Float oldPref = dataModel.getPreferenceValue(userID, itemID);
     super.removePreference(userID, itemID);
     if (oldPref != null) {
-      diffStorage.updateItemPref(itemID, oldPref, true);
+      diffStorage.removeItemPref(userID, itemID, oldPref);
     }
   }
   

Modified: 
mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/slopeone/file/FileDiffStorage.java
URL: 
http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/slopeone/file/FileDiffStorage.java?rev=1059365&r1=1059364&r2=1059365&view=diff
==============================================================================
--- 
mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/slopeone/file/FileDiffStorage.java
 (original)
+++ 
mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/slopeone/file/FileDiffStorage.java
 Sat Jan 15 16:28:49 2011
@@ -250,9 +250,15 @@ public final class FileDiffStorage imple
   public RunningAverage getAverageItemPref(long itemID) {
     return null; // TODO can't do this without a DataModel
   }
-  
+
+  @Override
+  public void addItemPref(long userID, long itemIDA, float prefValue) {
+    // Can't do this without a DataModel; should it just be a no-op?
+    throw new UnsupportedOperationException();
+  }
+
   @Override
-  public void updateItemPref(long itemID, float prefDelta, boolean remove) {
+  public void updateItemPref(long itemID, float prefDelta) {
     try {
       buildAverageDiffsLock.readLock().lock();
       for (Map.Entry<Long,FastByIDMap<RunningAverage>> entry : 
averageDiffs.entrySet()) {
@@ -260,17 +266,9 @@ public final class FileDiffStorage imple
         for (Map.Entry<Long,RunningAverage> entry2 : 
entry.getValue().entrySet()) {
           RunningAverage average = entry2.getValue();
           if (matchesItemID1) {
-            if (remove) {
-              average.removeDatum(prefDelta);
-            } else {
-              average.changeDatum(-prefDelta);
-            }
+            average.changeDatum(-prefDelta);
           } else if (itemID == entry2.getKey()) {
-            if (remove) {
-              average.removeDatum(-prefDelta);
-            } else {
-              average.changeDatum(prefDelta);
-            }
+            average.changeDatum(prefDelta);
           }
         }
       }
@@ -282,6 +280,12 @@ public final class FileDiffStorage imple
       buildAverageDiffsLock.readLock().unlock();
     }
   }
+
+  @Override
+  public void removeItemPref(long userID, long itemIDA, float prefValue) {
+    // Can't do this without a DataModel; should it just be a no-op?
+    throw new UnsupportedOperationException();
+  }
   
   @Override
   public FastIDSet getRecommendableItemIDs(long userID) {

Modified: 
mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/slopeone/jdbc/AbstractJDBCDiffStorage.java
URL: 
http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/slopeone/jdbc/AbstractJDBCDiffStorage.java?rev=1059365&r1=1059364&r2=1059365&view=diff
==============================================================================
--- 
mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/slopeone/jdbc/AbstractJDBCDiffStorage.java
 (original)
+++ 
mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/slopeone/jdbc/AbstractJDBCDiffStorage.java
 Sat Jan 15 16:28:49 2011
@@ -63,12 +63,16 @@ public abstract class AbstractJDBCDiffSt
   public static final String DEFAULT_AVERAGE_DIFF_COLUMN = "average_diff";
   public static final String DEFAULT_STDEV_COLUMN = "standard_deviation";
 
+  private final JDBCDataModel dataModel;
   private final DataSource dataSource;
   private final String getDiffSQL;
   private final String getDiffsSQL;
   private final String getAverageItemPrefSQL;
+  private final String getDiffsAffectedByUserSQL;
   private final String[] updateDiffSQLs;
-  private final String[] removeDiffSQLs;
+  private final String updateOneDiffSQL;
+  private final String addDiffSQL;
+  private final String removeDiffSQL;
   private final String getRecommendableItemsSQL;
   private final String deleteDiffsSQL;
   private final String createDiffsSQL;
@@ -80,8 +84,11 @@ public abstract class AbstractJDBCDiffSt
                                     String getDiffSQL,
                                     String getDiffsSQL,
                                     String getAverageItemPrefSQL,
+                                    String getDiffsAffectedByUserSQL,
                                     String[] updateDiffSQLs,
-                                    String[] removeDiffSQLs,
+                                    String updateOneDiffSQL,
+                                    String addDiffSQL,
+                                    String removeDiffSQL,
                                     String getRecommendableItemsSQL,
                                     String deleteDiffsSQL,
                                     String createDiffsSQL,
@@ -92,21 +99,28 @@ public abstract class AbstractJDBCDiffSt
     AbstractJDBCComponent.checkNotNullAndLog("getDiffSQL", getDiffSQL);
     AbstractJDBCComponent.checkNotNullAndLog("getDiffsSQL", getDiffsSQL);
     AbstractJDBCComponent.checkNotNullAndLog("getAverageItemPrefSQL", 
getAverageItemPrefSQL);
+    AbstractJDBCComponent.checkNotNullAndLog("getDiffsAffectedByUserSQL", 
getDiffsAffectedByUserSQL);
     AbstractJDBCComponent.checkNotNullAndLog("updateDiffSQLs", updateDiffSQLs);
-    AbstractJDBCComponent.checkNotNullAndLog("removeDiffSQLs", removeDiffSQLs);
+    AbstractJDBCComponent.checkNotNullAndLog("updateOneDiffSQL", 
updateOneDiffSQL);
+    AbstractJDBCComponent.checkNotNullAndLog("addDiffSQL", addDiffSQL);
+    AbstractJDBCComponent.checkNotNullAndLog("removeDiffSQL", removeDiffSQL);
     AbstractJDBCComponent.checkNotNullAndLog("getRecommendableItemsSQL", 
getRecommendableItemsSQL);
     AbstractJDBCComponent.checkNotNullAndLog("deleteDiffsSQL", deleteDiffsSQL);
     AbstractJDBCComponent.checkNotNullAndLog("createDiffsSQL", createDiffsSQL);
     AbstractJDBCComponent.checkNotNullAndLog("diffsExistSQL", diffsExistSQL);
 
     Preconditions.checkArgument(minDiffCount >= 0, "minDiffCount is not 
positive");
-    
+
+    this.dataModel = dataModel;
     this.dataSource = dataModel.getDataSource();
     this.getDiffSQL = getDiffSQL;
     this.getDiffsSQL = getDiffsSQL;
     this.getAverageItemPrefSQL = getAverageItemPrefSQL;
+    this.getDiffsAffectedByUserSQL = getDiffsAffectedByUserSQL;
     this.updateDiffSQLs = updateDiffSQLs;
-    this.removeDiffSQLs = removeDiffSQLs;
+    this.updateOneDiffSQL = updateOneDiffSQL;
+    this.addDiffSQL = addDiffSQL;
+    this.removeDiffSQL = removeDiffSQL;
     this.getRecommendableItemsSQL = getRecommendableItemsSQL;
     this.deleteDiffsSQL = deleteDiffsSQL;
     this.createDiffsSQL = createDiffsSQL;
@@ -223,23 +237,107 @@ public abstract class AbstractJDBCDiffSt
     }
   }
 
+
+  @Override
+  public void addItemPref(long userID, long itemID, float prefValue) throws 
TasteException {
+    // Collect the IDs from prefs.getIDs(); each is removed below once its diff row is updated
+    PreferenceArray prefs = dataModel.getPreferencesFromUser(userID);
+    FastIDSet unupdatedItemIDs = new FastIDSet();
+    for (long anItemID : prefs.getIDs()) {
+      unupdatedItemIDs.add(anItemID);
+    }
+
+    Connection conn = null;
+    PreparedStatement stmt = null;
+    ResultSet rs = null;
+    try {
+      conn = dataSource.getConnection();
+      stmt = conn.prepareStatement(getDiffsAffectedByUserSQL, 
ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY);
+      stmt.setFetchDirection(ResultSet.FETCH_FORWARD);
+      stmt.setFetchSize(getFetchSize());
+      stmt.setLong(1, userID);
+      log.debug("Executing SQL query: {}", getDiffsAffectedByUserSQL);
+      rs = stmt.executeQuery();
+
+      while (rs.next()) {
+        int count = rs.getInt(1);
+        float average = rs.getFloat(2);
+        long itemIDA = rs.getLong(3);
+        long itemIDB = rs.getLong(4);
+        float currentOtherPrefValue = rs.getFloat(5);
+        float prefDelta;
+        long otherItemID;
+        if (itemID == itemIDA) {
+          prefDelta = currentOtherPrefValue - prefValue;
+          otherItemID = itemIDB;
+        } else {
+          prefDelta = prefValue - currentOtherPrefValue;
+          otherItemID = itemIDA;
+        }
+        float newAverage = (average * count + prefDelta) / (count + 1);
+        updateOneDiff(conn, count + 1, newAverage, itemIDA, itemIDB);
+        unupdatedItemIDs.remove(otherItemID);
+      }
+
+    } catch (SQLException sqle) {
+      log.warn("Exception while adding item diff", sqle);
+      throw new TasteException(sqle);
+    } finally {
+      IOUtils.quietClose(rs, stmt, conn);
+    }
+
+    // Catch anything that wasn't already covered in the diff table
+    try {
+      conn = dataSource.getConnection();
+      stmt = conn.prepareStatement(addDiffSQL);
+      for (long unupdatedItemID : unupdatedItemIDs) {
+        if (unupdatedItemID < itemID) {
+          stmt.setLong(1, unupdatedItemID);
+          stmt.setLong(2, itemID);
+          stmt.setFloat(3, prefValue);
+        } else {
+          stmt.setLong(1, itemID);
+          stmt.setLong(2, unupdatedItemID);
+          stmt.setFloat(3, -prefValue);
+        }
+        log.debug("Executing SQL update: {}", addDiffSQL);
+        stmt.executeUpdate();
+      }
+    } catch (SQLException sqle) {
+      log.warn("Exception while adding item diff", sqle);
+      throw new TasteException(sqle);
+    } finally {
+      IOUtils.quietClose(null, stmt, conn);
+    }
+  }
+
+  private void updateOneDiff(Connection conn, int newCount, float newAverage, 
long itemIDA, long itemIDB)
+    throws SQLException {
+    PreparedStatement stmt = conn.prepareStatement(updateOneDiffSQL);
+    try {
+      stmt.setInt(1, newCount);
+      stmt.setFloat(2, newAverage);
+      stmt.setLong(3, itemIDA);
+      stmt.setLong(4, itemIDB);
+      log.debug("Executing SQL update: {}", updateOneDiffSQL);
+      stmt.executeUpdate();
+    } finally {
+      IOUtils.quietClose(stmt);
+    }
+  }
+
   /**
    * Note that this implementation does <em>not</em> update standard 
deviations. This would
    * be expensive relative to the value of slightly adjusting these values, 
which are merely
    * used as weighted. Rebuilding the diffs table will update standard 
deviations.
    */
   @Override
-  public void updateItemPref(long itemID, float prefDelta, boolean remove) 
throws TasteException {
+  public void updateItemPref(long itemID, float prefDelta) throws 
TasteException {
     Connection conn = null;
     try {
       conn = dataSource.getConnection();
-      if (remove) {
-        doPartialUpdate(removeDiffSQLs[0], itemID, prefDelta, conn);
-        doPartialUpdate(removeDiffSQLs[1], itemID, prefDelta, conn);
-      } else {
-        doPartialUpdate(updateDiffSQLs[0], itemID, prefDelta, conn);
-        doPartialUpdate(updateDiffSQLs[1], itemID, prefDelta, conn);
-      }
+      doPartialUpdate(updateDiffSQLs[0], itemID, prefDelta, conn);
+      doPartialUpdate(updateDiffSQLs[1], itemID, prefDelta, conn);
     } catch (SQLException sqle) {
       log.warn("Exception while updating item diff", sqle);
       throw new TasteException(sqle);
@@ -247,6 +345,61 @@ public abstract class AbstractJDBCDiffSt
       IOUtils.quietClose(conn);
     }
   }
+
+  @Override
+  public void removeItemPref(long userID, long itemID, float prefValue) throws 
TasteException {
+    Connection conn = null;
+    PreparedStatement stmt = null;
+    ResultSet rs = null;
+    try {
+      conn = dataSource.getConnection();
+      stmt = conn.prepareStatement(getDiffsAffectedByUserSQL, 
ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY);
+      stmt.setFetchDirection(ResultSet.FETCH_FORWARD);
+      stmt.setFetchSize(getFetchSize());
+      stmt.setLong(1, userID);
+      log.debug("Executing SQL query: {}", getDiffsAffectedByUserSQL);
+      rs = stmt.executeQuery();
+
+      while (rs.next()) {
+        int count = rs.getInt(1);
+        long itemIDA = rs.getLong(3);
+        long itemIDB = rs.getLong(4);
+        if (count == minDiffCount) {
+          // going to remove the diff
+          removeOneDiff(conn, itemIDA, itemIDB);
+        } else {
+          float average = rs.getFloat(2);
+          float currentOtherPrefValue = rs.getFloat(5);
+          float prefDelta;
+          if (itemID == itemIDA) {
+            prefDelta = currentOtherPrefValue - prefValue;
+          } else {
+            prefDelta = prefValue - currentOtherPrefValue;
+          }
+          float newAverage = (average * count - prefDelta) / (count - 1);
+          updateOneDiff(conn, count - 1, newAverage, itemIDA, itemIDB);
+        }
+      }
+    } catch (SQLException sqle) {
+      log.warn("Exception while removing item diff", sqle);
+      throw new TasteException(sqle);
+    } finally {
+      IOUtils.quietClose(rs, stmt, conn);
+    }
+  }
+
+  private void removeOneDiff(Connection conn, long itemIDA, long itemIDB)
+    throws SQLException {
+    PreparedStatement stmt = conn.prepareStatement(removeDiffSQL);
+    try {
+      stmt.setLong(1, itemIDA);
+      stmt.setLong(2, itemIDB);
+      log.debug("Executing SQL update: {}", removeDiffSQL);
+      stmt.executeUpdate();
+    } finally {
+      IOUtils.quietClose(stmt);
+    }
+  }
   
   private static void doPartialUpdate(String sql, long itemID, double 
prefDelta, Connection conn) throws SQLException {
     PreparedStatement stmt = conn.prepareStatement(sql);

Modified: 
mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/slopeone/jdbc/MySQLJDBCDiffStorage.java
URL: 
http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/slopeone/jdbc/MySQLJDBCDiffStorage.java?rev=1059365&r1=1059364&r2=1059365&view=diff
==============================================================================
--- 
mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/slopeone/jdbc/MySQLJDBCDiffStorage.java
 (original)
+++ 
mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/slopeone/jdbc/MySQLJDBCDiffStorage.java
 Sat Jan 15 16:28:49 2011
@@ -125,22 +125,29 @@ public final class MySQLJDBCDiffStorage 
         // getAverageItemPrefSQL
         "SELECT COUNT(1), AVG(" + dataModel.getPreferenceColumn() + ") FROM "
             + dataModel.getPreferenceTable() + " WHERE " + 
dataModel.getItemIDColumn() + "=?",
+        // getDiffsAffectedByUserSQL
+        "SELECT diffs." + countColumn + ", diffs." + avgColumn + ", diffs." + 
itemIDAColumn
+            + ", diffs." + itemIDBColumn + ", prefs." + 
dataModel.getPreferenceColumn()
+            + " FROM " + diffsTable + " AS diffs, " + 
dataModel.getPreferenceTable() + " AS prefs WHERE prefs."
+            + dataModel.getUserIDColumn() + "=? AND (prefs." + 
dataModel.getItemIDColumn()
+            + " = diffs." + itemIDAColumn + " OR prefs." + 
dataModel.getItemIDColumn()
+            + " = diffs." + itemIDBColumn + ')',
         // updateDiffSQLs
         new String[] {
-          "UPDATE " + diffsTable + " SET " + avgColumn + " = " + avgColumn + " 
- (? / "
-              + countColumn + ") WHERE " + itemIDAColumn + "=?",
-          "UPDATE " + diffsTable + " SET " + avgColumn + " = " + avgColumn + " 
+ (? / "
-              + countColumn + ") WHERE " + itemIDBColumn + "=?"},
+          "UPDATE " + diffsTable + " SET "
+              + avgColumn + " = " + avgColumn + " - (? / " + countColumn
+              + ") WHERE " + itemIDAColumn + "=?",
+          "UPDATE " + diffsTable + " SET "
+              + avgColumn + " = " + avgColumn + " + (? / " + countColumn
+              + ") WHERE " + itemIDBColumn + "=?"},
+        // updateOneDiffSQL
+        "UPDATE " + diffsTable + " SET " + countColumn + "=?, " + avgColumn + 
"=? WHERE "
+            + itemIDAColumn + "=? AND " + itemIDBColumn + "=?",
+        // addDiffSQL
+        "INSERT INTO " + diffsTable + " (" + itemIDAColumn + ", " + 
itemIDBColumn + ", " + avgColumn
+            + ", " + stdevColumn + ", " + countColumn + ") VALUES (?,?,?,0,1)",
         // removeDiffSQL
-        new String[] {
-          "UPDATE " + diffsTable + " SET " + countColumn + " = " + countColumn 
+ "-1, "
-              + avgColumn + " = " + avgColumn + " * ((" + countColumn + " + 1) 
/ CAST("
-              + countColumn + " AS DECIMAL)) + ? / CAST(" + countColumn + " AS 
DECIMAL) WHERE "
-              + itemIDAColumn + "=?",
-          "UPDATE " + diffsTable + " SET " + countColumn + " = " + countColumn 
+ "-1, "
-              + avgColumn + " = " + avgColumn + " * ((" + countColumn + " + 1) 
/ CAST("
-              + countColumn + " AS DECIMAL)) - ? / CAST(" + countColumn + " AS 
DECIMAL) WHERE "
-              + itemIDBColumn + "=?"},
+        "DELETE FROM " + diffsTable + " WHERE " + itemIDAColumn + "=? AND " + 
itemIDBColumn + "=?",
         // getRecommendableItemsSQL
         "SELECT id FROM " + "(SELECT " + itemIDAColumn + " AS id FROM " + 
diffsTable + ", "
             + dataModel.getPreferenceTable() + " WHERE " + itemIDBColumn + " = 
"

Modified: 
mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/recommender/slopeone/DiffStorage.java
URL: 
http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/recommender/slopeone/DiffStorage.java?rev=1059365&r1=1059364&r2=1059365&view=diff
==============================================================================
--- 
mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/recommender/slopeone/DiffStorage.java
 (original)
+++ 
mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/recommender/slopeone/DiffStorage.java
 Sat Jan 15 16:28:49 2011
@@ -55,20 +55,32 @@ public interface DiffStorage extends Ref
   
   /** @return {@link RunningAverage} encapsulating the average preference for 
the given item */
   RunningAverage getAverageItemPref(long itemID) throws TasteException;
-  
+
+  /**
+   * <p>Updates internal data structures to reflect a new preference value for 
an item.</p>
+   *
+   * @param userID user whose pref is being added
+   * @param itemID item to add preference value for
+   * @param prefValue new preference value
+   */
+  void addItemPref(long userID, long itemID, float prefValue) throws 
TasteException;
+
   /**
-   * <p>
-   * Updates internal data structures to reflect an update in a preference 
value for an item.
-   * </p>
+   * <p>Updates internal data structures to reflect an update in a preference 
value for an item.</p>
    * 
-   * @param itemID
-   *          item to update preference value for
-   * @param prefDelta
-   *          amount by which preference value changed (or its old value, if 
being removed
-   * @param remove
-   *          if <code>true</code>, operation reflects a removal rather than 
change of preference
+   * @param itemID item to update preference value for
+   * @param prefDelta amount by which preference value changed
+   */
+  void updateItemPref(long itemID, float prefDelta) throws TasteException;
+
+  /**
+   * <p>Updates internal data structures to reflect the removal of a preference 
value for an item.</p>
+   *
+   * @param userID user whose pref is being removed
+   * @param itemID item to update preference value for
+   * @param prefValue old preference value
    */
-  void updateItemPref(long itemID, float prefDelta, boolean remove) throws 
TasteException;
+  void removeItemPref(long userID, long itemID, float prefValue) throws 
TasteException;
   
   /**
    * @return item IDs that may possibly be recommended to the given user, 
which may not be all items since the

Modified: 
mahout/trunk/core/src/test/java/org/apache/mahout/cf/taste/impl/recommender/slopeone/MemoryDiffStorageTest.java
URL: 
http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/cf/taste/impl/recommender/slopeone/MemoryDiffStorageTest.java?rev=1059365&r1=1059364&r2=1059365&view=diff
==============================================================================
--- 
mahout/trunk/core/src/test/java/org/apache/mahout/cf/taste/impl/recommender/slopeone/MemoryDiffStorageTest.java
 (original)
+++ 
mahout/trunk/core/src/test/java/org/apache/mahout/cf/taste/impl/recommender/slopeone/MemoryDiffStorageTest.java
 Sat Jan 15 16:28:49 2011
@@ -24,7 +24,7 @@ import org.apache.mahout.cf.taste.model.
 import org.junit.Test;
 
 /** Tests {@link MemoryDiffStorage}. */
-public class MemoryDiffStorageTest extends TasteTestCase {
+public final class MemoryDiffStorageTest extends TasteTestCase {
 
   @Test
   public void testGetDiff() throws Exception {
@@ -36,11 +36,41 @@ public class MemoryDiffStorageTest exten
   }
 
   @Test
+  public void testAdd() throws Exception {
+    DataModel model = getDataModel();
+    MemoryDiffStorage storage = new MemoryDiffStorage(model, 
Weighting.UNWEIGHTED, false, Long.MAX_VALUE);
+
+    RunningAverage average1 = storage.getDiff(0, 2);
+    assertEquals(0.1, average1.getAverage(), EPSILON);
+    assertEquals(3, average1.getCount());
+
+    RunningAverage average2 = storage.getDiff(1, 2);
+    assertEquals(0.23333332935969034, average2.getAverage(), EPSILON);
+    assertEquals(3, average2.getCount());
+
+    storage.addItemPref(1, 2, 0.8f);
+
+    average1 = storage.getDiff(0, 2);
+    assertEquals(0.25, average1.getAverage(), EPSILON);
+    assertEquals(4, average1.getCount());
+
+    average2 = storage.getDiff(1, 2);
+    assertEquals(0.3, average2.getAverage(), EPSILON);
+    assertEquals(4, average2.getCount());
+  }
+
+  @Test
   public void testUpdate() throws Exception {
     DataModel model = getDataModel();
     MemoryDiffStorage storage = new MemoryDiffStorage(model, 
Weighting.UNWEIGHTED, false, Long.MAX_VALUE);
-    storage.updateItemPref(1, 0.5f, false);
+
     RunningAverage average = storage.getDiff(1, 2);
+    assertEquals(0.23333332935969034, average.getAverage(), EPSILON);
+    assertEquals(3, average.getCount());
+
+    storage.updateItemPref(1, 0.5f);
+
+    average = storage.getDiff(1, 2);
     assertEquals(0.06666666666666668, average.getAverage(), EPSILON);
     assertEquals(3, average.getCount());
   }
@@ -49,10 +79,24 @@ public class MemoryDiffStorageTest exten
   public void testRemove() throws Exception {
     DataModel model = getDataModel();
     MemoryDiffStorage storage = new MemoryDiffStorage(model, 
Weighting.UNWEIGHTED, false, Long.MAX_VALUE);
-    storage.updateItemPref(1, 0.5f, true);
-    RunningAverage average = storage.getDiff(1, 2);
-    assertEquals(0.1, average.getAverage(), EPSILON);
-    assertEquals(2, average.getCount());
+
+    RunningAverage average1 = storage.getDiff(0, 2);
+    assertEquals(0.1, average1.getAverage(), EPSILON);
+    assertEquals(3, average1.getCount());
+
+    RunningAverage average2 = storage.getDiff(1, 2);
+    assertEquals(0.23333332935969034, average2.getAverage(), EPSILON);
+    assertEquals(3, average2.getCount());
+
+    storage.removeItemPref(4, 2, 0.8f);
+
+    average1 = storage.getDiff(0, 2);
+    assertEquals(0.1, average1.getAverage(), EPSILON);
+    assertEquals(2, average1.getCount());
+
+    average2 = storage.getDiff(1, 2);
+    assertEquals(0.1, average2.getAverage(), EPSILON);
+    assertEquals(2, average2.getCount());
   }
 
 }


Reply via email to