Author: tommaso
Date: Thu Jun  8 09:11:37 2017
New Revision: 1798035

URL: http://svn.apache.org/viewvc?rev=1798035&view=rev
Log:
OAK-6317 - fixed LMSEstimator update rule

Modified:
    
jackrabbit/oak/trunk/oak-solr-core/src/main/java/org/apache/jackrabbit/oak/plugins/index/solr/query/LMSEstimator.java
    
jackrabbit/oak/trunk/oak-solr-core/src/test/java/org/apache/jackrabbit/oak/plugins/index/solr/query/LMSEstimatorTest.java

Modified: 
jackrabbit/oak/trunk/oak-solr-core/src/main/java/org/apache/jackrabbit/oak/plugins/index/solr/query/LMSEstimator.java
URL: 
http://svn.apache.org/viewvc/jackrabbit/oak/trunk/oak-solr-core/src/main/java/org/apache/jackrabbit/oak/plugins/index/solr/query/LMSEstimator.java?rev=1798035&r1=1798034&r2=1798035&view=diff
==============================================================================
--- 
jackrabbit/oak/trunk/oak-solr-core/src/main/java/org/apache/jackrabbit/oak/plugins/index/solr/query/LMSEstimator.java
 (original)
+++ 
jackrabbit/oak/trunk/oak-solr-core/src/main/java/org/apache/jackrabbit/oak/plugins/index/solr/query/LMSEstimator.java
 Thu Jun  8 09:11:37 2017
@@ -22,11 +22,13 @@ import org.apache.jackrabbit.oak.spi.que
 import org.apache.solr.common.SolrDocumentList;
 
 /**
- * A very simple estimator for no. of entries in the index using least mean square update method but not the full stochastic
- * gradient descent algorithm (yet?), on a linear interpolation model.
+ * A very simple estimator for no. of entries in the index using least mean square update method for linear regression.
  */
 class LMSEstimator {
 
+    private static final double DEFAULT_ALPHA = 0.03;
+    private static final int DEFAULT_THRESHOLD = 5;
+
     private double[] weights;
     private final double alpha;
     private final long threshold;
@@ -38,23 +40,25 @@ class LMSEstimator {
     }
 
     LMSEstimator(double[] weights) {
-        this(0.03, weights, 5);
+        this(DEFAULT_ALPHA, weights, DEFAULT_THRESHOLD);
     }
 
     LMSEstimator() {
-        this(0.03, new double[5], 5);
+        this(DEFAULT_ALPHA, new double[5], 5);
     }
 
     synchronized void update(Filter filter, SolrDocumentList docs) {
         double[] updatedWeights = new double[weights.length];
+
+        // least mean square cost
         long estimate = estimate(filter);
         long numFound = docs.getNumFound();
-        long diff = numFound - estimate;
-        double delta = Math.pow(diff, 2) / 2;
+        long residual = numFound - estimate;
+        double delta = Math.pow(residual, 2);
+
         if (Math.abs(delta) > threshold) {
             for (int i = 0; i < updatedWeights.length; i++) {
-                double errors = delta * getInput(filter, i);
-                updatedWeights[i] = weights[i] + (diff > 0 ? 1 : -1) * alpha * errors;
+                updatedWeights[i] = weights[i] + alpha * residual * getInput(filter, i);
             }
             // weights updated
             weights = Arrays.copyOf(updatedWeights, 5);

Modified: 
jackrabbit/oak/trunk/oak-solr-core/src/test/java/org/apache/jackrabbit/oak/plugins/index/solr/query/LMSEstimatorTest.java
URL: 
http://svn.apache.org/viewvc/jackrabbit/oak/trunk/oak-solr-core/src/test/java/org/apache/jackrabbit/oak/plugins/index/solr/query/LMSEstimatorTest.java?rev=1798035&r1=1798034&r2=1798035&view=diff
==============================================================================
--- 
jackrabbit/oak/trunk/oak-solr-core/src/test/java/org/apache/jackrabbit/oak/plugins/index/solr/query/LMSEstimatorTest.java
 (original)
+++ 
jackrabbit/oak/trunk/oak-solr-core/src/test/java/org/apache/jackrabbit/oak/plugins/index/solr/query/LMSEstimatorTest.java
 Thu Jun  8 09:11:37 2017
@@ -53,18 +53,22 @@ public class LMSEstimatorTest {
         docs.setNumFound(actualCount);
 
         long estimate = lmsEstimator.estimate(filter);
+        assertEquals(estimate, lmsEstimator.estimate(filter));
         long diff = actualCount - estimate;
 
         // update causes weights adjustment
         lmsEstimator.update(filter, docs);
         long estimate2 = lmsEstimator.estimate(filter);
+        assertEquals(estimate2, lmsEstimator.estimate(filter));
         long diff2 = actualCount - estimate2;
         assertTrue(diff2 < diff); // new estimate is more accurate than previous one
 
         // update doesn't cause weight adjustments therefore estimates stays unchanged
         lmsEstimator.update(filter, docs);
         long estimate3 = lmsEstimator.estimate(filter);
-        assertEquals(estimate3, estimate2);
+        assertEquals(estimate3, lmsEstimator.estimate(filter));
+        long diff3 = actualCount - estimate3;
+        assertTrue(diff3 < diff2);
     }
 
     @Test


Reply via email to