Author: tommaso
Date: Thu Jun 8 09:11:37 2017
New Revision: 1798035
URL: http://svn.apache.org/viewvc?rev=1798035&view=rev
Log:
OAK-6317 - fixed LMSEstimator update rule
Modified:
jackrabbit/oak/trunk/oak-solr-core/src/main/java/org/apache/jackrabbit/oak/plugins/index/solr/query/LMSEstimator.java
jackrabbit/oak/trunk/oak-solr-core/src/test/java/org/apache/jackrabbit/oak/plugins/index/solr/query/LMSEstimatorTest.java
Modified:
jackrabbit/oak/trunk/oak-solr-core/src/main/java/org/apache/jackrabbit/oak/plugins/index/solr/query/LMSEstimator.java
URL:
http://svn.apache.org/viewvc/jackrabbit/oak/trunk/oak-solr-core/src/main/java/org/apache/jackrabbit/oak/plugins/index/solr/query/LMSEstimator.java?rev=1798035&r1=1798034&r2=1798035&view=diff
==============================================================================
--- jackrabbit/oak/trunk/oak-solr-core/src/main/java/org/apache/jackrabbit/oak/plugins/index/solr/query/LMSEstimator.java (original)
+++ jackrabbit/oak/trunk/oak-solr-core/src/main/java/org/apache/jackrabbit/oak/plugins/index/solr/query/LMSEstimator.java Thu Jun 8 09:11:37 2017
@@ -22,11 +22,13 @@ import org.apache.jackrabbit.oak.spi.que
import org.apache.solr.common.SolrDocumentList;
/**
- * A very simple estimator for no. of entries in the index using least mean square update method but not the full stochastic
- * gradient descent algorithm (yet?), on a linear interpolation model.
+ * A very simple estimator for no. of entries in the index using least mean square update method for linear regression.
*/
class LMSEstimator {
+ private static final double DEFAULT_ALPHA = 0.03;
+ private static final int DEFAULT_THRESHOLD = 5;
+
private double[] weights;
private final double alpha;
private final long threshold;
@@ -38,23 +40,25 @@ class LMSEstimator {
}
LMSEstimator(double[] weights) {
- this(0.03, weights, 5);
+ this(DEFAULT_ALPHA, weights, DEFAULT_THRESHOLD);
}
LMSEstimator() {
- this(0.03, new double[5], 5);
+ this(DEFAULT_ALPHA, new double[5], 5);
}
synchronized void update(Filter filter, SolrDocumentList docs) {
double[] updatedWeights = new double[weights.length];
+
+ // least mean square cost
long estimate = estimate(filter);
long numFound = docs.getNumFound();
- long diff = numFound - estimate;
- double delta = Math.pow(diff, 2) / 2;
+ long residual = numFound - estimate;
+ double delta = Math.pow(residual, 2);
+
if (Math.abs(delta) > threshold) {
for (int i = 0; i < updatedWeights.length; i++) {
- double errors = delta * getInput(filter, i);
- updatedWeights[i] = weights[i] + (diff > 0 ? 1 : -1) * alpha * errors;
+ updatedWeights[i] = weights[i] + alpha * residual * getInput(filter, i);
}
// weights updated
weights = Arrays.copyOf(updatedWeights, 5);
Modified:
jackrabbit/oak/trunk/oak-solr-core/src/test/java/org/apache/jackrabbit/oak/plugins/index/solr/query/LMSEstimatorTest.java
URL:
http://svn.apache.org/viewvc/jackrabbit/oak/trunk/oak-solr-core/src/test/java/org/apache/jackrabbit/oak/plugins/index/solr/query/LMSEstimatorTest.java?rev=1798035&r1=1798034&r2=1798035&view=diff
==============================================================================
--- jackrabbit/oak/trunk/oak-solr-core/src/test/java/org/apache/jackrabbit/oak/plugins/index/solr/query/LMSEstimatorTest.java (original)
+++ jackrabbit/oak/trunk/oak-solr-core/src/test/java/org/apache/jackrabbit/oak/plugins/index/solr/query/LMSEstimatorTest.java Thu Jun 8 09:11:37 2017
@@ -53,18 +53,22 @@ public class LMSEstimatorTest {
docs.setNumFound(actualCount);
long estimate = lmsEstimator.estimate(filter);
+ assertEquals(estimate, lmsEstimator.estimate(filter));
long diff = actualCount - estimate;
// update causes weights adjustment
lmsEstimator.update(filter, docs);
long estimate2 = lmsEstimator.estimate(filter);
+ assertEquals(estimate2, lmsEstimator.estimate(filter));
long diff2 = actualCount - estimate2;
assertTrue(diff2 < diff); // new estimate is more accurate than previous one
// update doesn't cause weight adjustments therefore estimates stays unchanged
lmsEstimator.update(filter, docs);
long estimate3 = lmsEstimator.estimate(filter);
- assertEquals(estimate3, estimate2);
+ assertEquals(estimate3, lmsEstimator.estimate(filter));
+ long diff3 = actualCount - estimate3;
+ assertTrue(diff3 < diff2);
}
@Test