This is an automated email from the ASF dual-hosted git repository. dweiss pushed a commit to branch jira/solr-13105-toMerge in repository https://gitbox.apache.org/repos/asf/solr.git
commit 7c03cae553a0bc83d90920804c6714204eaa8391 Author: zacharymorn <[email protected]> AuthorDate: Mon Jan 11 06:03:29 2021 -0800 LUCENE-9346: Support minimumNumberShouldMatch in WANDScorer (#2141) Co-authored-by: Adrien Grand <[email protected]> --- .../lucene/search/Boolean2ScorerSupplier.java | 9 +- .../lucene/search/MinShouldMatchSumScorer.java | 30 +-- .../java/org/apache/lucene/search/ScorerUtil.java | 49 +++++ .../java/org/apache/lucene/search/WANDScorer.java | 40 +++- .../org/apache/lucene/search/TestWANDScorer.java | 225 +++++++++++++++++++++ 5 files changed, 309 insertions(+), 44 deletions(-) diff --git a/lucene/core/src/java/org/apache/lucene/search/Boolean2ScorerSupplier.java b/lucene/core/src/java/org/apache/lucene/search/Boolean2ScorerSupplier.java index 3fa5886..d0a4dbd 100644 --- a/lucene/core/src/java/org/apache/lucene/search/Boolean2ScorerSupplier.java +++ b/lucene/core/src/java/org/apache/lucene/search/Boolean2ScorerSupplier.java @@ -75,7 +75,7 @@ final class Boolean2ScorerSupplier extends ScorerSupplier { } else { final Collection<ScorerSupplier> optionalScorers = subs.get(Occur.SHOULD); final long shouldCost = - MinShouldMatchSumScorer.cost( + ScorerUtil.costWithMinShouldMatch( optionalScorers.stream().mapToLong(ScorerSupplier::cost), optionalScorers.size(), minShouldMatch); @@ -230,10 +230,11 @@ final class Boolean2ScorerSupplier extends ScorerSupplier { for (ScorerSupplier scorer : optional) { optionalScorers.add(scorer.get(leadCost)); } - if (minShouldMatch > 1) { + + if (scoreMode == ScoreMode.TOP_SCORES) { + return new WANDScorer(weight, optionalScorers, minShouldMatch); + } else if (minShouldMatch > 1) { return new MinShouldMatchSumScorer(weight, optionalScorers, minShouldMatch); - } else if (scoreMode == ScoreMode.TOP_SCORES) { - return new WANDScorer(weight, optionalScorers); } else { return new DisjunctionSumScorer(weight, optionalScorers, scoreMode); } diff --git a/lucene/core/src/java/org/apache/lucene/search/MinShouldMatchSumScorer.java b/lucene/core/src/java/org/apache/lucene/search/MinShouldMatchSumScorer.java index bdcdca9..574fd1a 100644 --- a/lucene/core/src/java/org/apache/lucene/search/MinShouldMatchSumScorer.java +++ b/lucene/core/src/java/org/apache/lucene/search/MinShouldMatchSumScorer.java @@ -24,9 +24,6 @@ import java.io.IOException; import java.util.ArrayList; import java.util.Collection; import java.util.List; -import java.util.stream.LongStream; -import java.util.stream.StreamSupport; -import org.apache.lucene.util.PriorityQueue; /** * A {@link Scorer} for {@link BooleanQuery} when {@link @@ -44,31 +41,6 @@ import org.apache.lucene.util.PriorityQueue; */ final class MinShouldMatchSumScorer extends Scorer { - static long cost(LongStream costs, int numScorers, int minShouldMatch) { - // the idea here is the following: a boolean query c1,c2,...cn with minShouldMatch=m - // could be rewritten to: - // (c1 AND (c2..cn|msm=m-1)) OR (!c1 AND (c2..cn|msm=m)) - // if we assume that clauses come in ascending cost, then - // the cost of the first part is the cost of c1 (because the cost of a conjunction is - // the cost of the least costly clause) - // the cost of the second part is the cost of finding m matches among the c2...cn - // remaining clauses - // since it is a disjunction overall, the total cost is the sum of the costs of these - // two parts - - // If we recurse infinitely, we find out that the cost of a msm query is the sum of the - // costs of the num_scorers - minShouldMatch + 1 least costly scorers - final PriorityQueue<Long> pq = - new PriorityQueue<Long>(numScorers - minShouldMatch + 1) { - @Override - protected boolean lessThan(Long a, Long b) { - return a > b; - } - }; - costs.forEach(pq::insertWithOverflow); - return StreamSupport.stream(pq.spliterator(), false).mapToLong(Number::longValue).sum(); - } - final int minShouldMatch; // list of scorers which 'lead' the iteration and are currently @@ -111,7 +83,7 @@ final class MinShouldMatchSumScorer extends Scorer { } this.cost = - cost( + ScorerUtil.costWithMinShouldMatch( scorers.stream().map(Scorer::iterator).mapToLong(DocIdSetIterator::cost), scorers.size(), minShouldMatch); diff --git a/lucene/core/src/java/org/apache/lucene/search/ScorerUtil.java b/lucene/core/src/java/org/apache/lucene/search/ScorerUtil.java new file mode 100644 index 0000000..50c9607 --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/search/ScorerUtil.java @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.search; + +import java.util.stream.LongStream; +import java.util.stream.StreamSupport; +import org.apache.lucene.util.PriorityQueue; + +/** Util class for Scorer related methods */ +class ScorerUtil { + static long costWithMinShouldMatch(LongStream costs, int numScorers, int minShouldMatch) { + // the idea here is the following: a boolean query c1,c2,...cn with minShouldMatch=m + // could be rewritten to: + // (c1 AND (c2..cn|msm=m-1)) OR (!c1 AND (c2..cn|msm=m)) + // if we assume that clauses come in ascending cost, then + // the cost of the first part is the cost of c1 (because the cost of a conjunction is + // the cost of the least costly clause) + // the cost of the second part is the cost of finding m matches among the c2...cn + // remaining clauses + // since it is a disjunction overall, the total cost is the sum of the costs of these + // two parts + + // If we recurse infinitely, we find out that the cost of a msm query is the sum of the + // costs of the num_scorers - minShouldMatch + 1 least costly scorers + final PriorityQueue<Long> pq = + new PriorityQueue<Long>(numScorers - minShouldMatch + 1) { + @Override + protected boolean lessThan(Long a, Long b) { + return a > b; + } + }; + costs.forEach(pq::insertWithOverflow); + return StreamSupport.stream(pq.spliterator(), false).mapToLong(Number::longValue).sum(); + } +} diff --git a/lucene/core/src/java/org/apache/lucene/search/WANDScorer.java b/lucene/core/src/java/org/apache/lucene/search/WANDScorer.java index 2c94159..b1ed3bf 100644 --- a/lucene/core/src/java/org/apache/lucene/search/WANDScorer.java +++ b/lucene/core/src/java/org/apache/lucene/search/WANDScorer.java @@ -19,6 +19,7 @@ package org.apache.lucene.search; import static org.apache.lucene.search.DisiPriorityQueue.leftNode; import static org.apache.lucene.search.DisiPriorityQueue.parentNode; import static org.apache.lucene.search.DisiPriorityQueue.rightNode; +import static org.apache.lucene.search.ScorerUtil.costWithMinShouldMatch; import java.io.IOException; import java.util.ArrayList; @@ -130,10 +131,21 @@ final class WANDScorer extends Scorer { int upTo; // upper bound for which max scores are valid - WANDScorer(Weight weight, Collection<Scorer> scorers) throws IOException { + final int minShouldMatch; + int freq; + + WANDScorer(Weight weight, Collection<Scorer> scorers, int minShouldMatch) throws IOException { super(weight); + if (minShouldMatch >= scorers.size()) { + throw new IllegalArgumentException("minShouldMatch should be < the number of scorers"); + } + this.minCompetitiveScore = 0; + + assert minShouldMatch >= 0 : "minShouldMatch should not be negative, but got " + minShouldMatch; + this.minShouldMatch = minShouldMatch; + this.doc = -1; this.upTo = -1; // will be computed on the first call to nextDoc/advance @@ -155,13 +167,15 @@ final class WANDScorer extends Scorer { // Use a scaling factor of 0 if all max scores are either 0 or +Infty this.scalingFactor = scalingFactor.orElse(0); - long cost = 0; for (Scorer scorer : scorers) { - DisiWrapper w = new DisiWrapper(scorer); - cost += w.cost; - addLead(w); + addLead(new DisiWrapper(scorer)); } - this.cost = cost; + + this.cost = + costWithMinShouldMatch( + scorers.stream().map(Scorer::iterator).mapToLong(DocIdSetIterator::cost), + scorers.size(), + minShouldMatch); this.maxScorePropagator = new MaxScoreSumPropagator(scorers); } @@ -265,15 +279,17 @@ final class WANDScorer extends Scorer { @Override public boolean matches() throws IOException { - while (leadMaxScore < minCompetitiveScore) { - if (leadMaxScore + tailMaxScore >= minCompetitiveScore) { + while (leadMaxScore < minCompetitiveScore || freq < minShouldMatch) { + if (leadMaxScore + tailMaxScore < minCompetitiveScore + || freq + tailSize < minShouldMatch) { + return false; + } else { // a match on doc is still possible, try to // advance scorers from the tail advanceTail(); - } else { - return false; } } + return true; } @@ -290,6 +306,7 @@ final class WANDScorer extends Scorer { lead.next = this.lead; this.lead = lead; leadMaxScore += lead.maxScore; + freq += 1; } /** Move disis that are in 'lead' back to the tail. */ @@ -429,6 +446,7 @@ final class WANDScorer extends Scorer { lead = head.pop(); lead.next = null; leadMaxScore = lead.maxScore; + freq = 1; doc = lead.doc; while (head.size() > 0 && head.top().doc == doc) { addLead(head.pop()); @@ -437,7 +455,7 @@ final class WANDScorer extends Scorer { /** Move iterators to the tail until there is a potential match. */ private int doNextCompetitiveCandidate() throws IOException { - while (leadMaxScore + tailMaxScore < minCompetitiveScore) { + while (leadMaxScore + tailMaxScore < minCompetitiveScore || freq + tailSize < minShouldMatch) { // no match on doc is possible, move to the next potential match pushBackLeads(doc + 1); moveToNextCandidate(doc + 1); diff --git a/lucene/core/src/test/org/apache/lucene/search/TestWANDScorer.java b/lucene/core/src/test/org/apache/lucene/search/TestWANDScorer.java index 543237b..c9381fe 100644 --- a/lucene/core/src/test/org/apache/lucene/search/TestWANDScorer.java +++ b/lucene/core/src/test/org/apache/lucene/search/TestWANDScorer.java @@ -229,6 +229,231 @@ public class TestWANDScorer extends LuceneTestCase { dir.close(); } + public void testBasicsWithDisjunctionAndMinShouldMatch() throws Exception { + try (Directory dir = newDirectory()) { + try (IndexWriter w = + new IndexWriter(dir, newIndexWriterConfig().setMergePolicy(newLogMergePolicy()))) { + for (String[] values : + Arrays.asList( + new String[] {"A", "B"}, // 0 + new String[] {"A"}, // 1 + new String[] {}, // 2 + new String[] {"A", "B", "C"}, // 3 + new String[] {"B"}, // 4 + new String[] {"B", "C"} // 5 + )) { + Document doc = new Document(); + for (String value : values) { + doc.add(new StringField("foo", value, Store.NO)); + } + w.addDocument(doc); + } + + w.forceMerge(1); + } + + try (IndexReader reader = DirectoryReader.open(dir)) { + IndexSearcher searcher = newSearcher(reader); + + Query query = + new BooleanQuery.Builder() + .add( + new BoostQuery(new ConstantScoreQuery(new TermQuery(new Term("foo", "A"))), 2), + Occur.SHOULD) + .add(new ConstantScoreQuery(new TermQuery(new Term("foo", "B"))), Occur.SHOULD) + .add( + new BoostQuery(new ConstantScoreQuery(new TermQuery(new Term("foo", "C"))), 3), + Occur.SHOULD) + .setMinimumNumberShouldMatch(2) + .build(); + + Scorer scorer = + searcher + .createWeight(searcher.rewrite(query), ScoreMode.TOP_SCORES, 1) + .scorer(searcher.getIndexReader().leaves().get(0)); + + assertEquals(0, scorer.iterator().nextDoc()); + assertEquals(2 + 1, scorer.score(), 0); + + assertEquals(3, scorer.iterator().nextDoc()); + assertEquals(2 + 1 + 3, scorer.score(), 0); + + assertEquals(5, scorer.iterator().nextDoc()); + assertEquals(1 + 3, scorer.score(), 0); + + assertEquals(DocIdSetIterator.NO_MORE_DOCS, scorer.iterator().nextDoc()); + + scorer = + searcher + .createWeight(searcher.rewrite(query), ScoreMode.TOP_SCORES, 1) + .scorer(searcher.getIndexReader().leaves().get(0)); + scorer.setMinCompetitiveScore(4); + + assertEquals(3, scorer.iterator().nextDoc()); + assertEquals(2 + 1 + 3, scorer.score(), 0); + + assertEquals(5, scorer.iterator().nextDoc()); + assertEquals(1 + 3, scorer.score(), 0); + + assertEquals(DocIdSetIterator.NO_MORE_DOCS, scorer.iterator().nextDoc()); + + scorer = + searcher + .createWeight(searcher.rewrite(query), ScoreMode.TOP_SCORES, 1) + .scorer(searcher.getIndexReader().leaves().get(0)); + + assertEquals(0, scorer.iterator().nextDoc()); + assertEquals(2 + 1, scorer.score(), 0); + + scorer.setMinCompetitiveScore(10); + + assertEquals(DocIdSetIterator.NO_MORE_DOCS, scorer.iterator().nextDoc()); + } + } + } + + public void testBasicsWithFilteredDisjunctionAndMinShouldMatch() throws Exception { + try (Directory dir = newDirectory()) { + try (IndexWriter w = + new IndexWriter(dir, newIndexWriterConfig().setMergePolicy(newLogMergePolicy()))) { + for (String[] values : + Arrays.asList( + new String[] {"A", "B"}, // 0 + new String[] {"A", "C", "D"}, // 1 + new String[] {}, // 2 + new String[] {"A", "B", "C", "D"}, // 3 + new String[] {"B"}, // 4 + new String[] {"C", "D"} // 5 + )) { + Document doc = new Document(); + for (String value : values) { + doc.add(new StringField("foo", value, Store.NO)); + } + w.addDocument(doc); + } + + w.forceMerge(1); + } + + try (IndexReader reader = DirectoryReader.open(dir)) { + IndexSearcher searcher = newSearcher(reader); + + Query query = + new BooleanQuery.Builder() + .add( + new BooleanQuery.Builder() + .add( + new BoostQuery( + new ConstantScoreQuery(new TermQuery(new Term("foo", "A"))), 2), + Occur.SHOULD) + .add( + new ConstantScoreQuery(new TermQuery(new Term("foo", "B"))), + Occur.SHOULD) + .add( + new BoostQuery( + new ConstantScoreQuery(new TermQuery(new Term("foo", "D"))), 4), + Occur.SHOULD) + .setMinimumNumberShouldMatch(2) + .build(), + Occur.MUST) + .add(new TermQuery(new Term("foo", "C")), Occur.FILTER) + .build(); + + Scorer scorer = + searcher + .createWeight(searcher.rewrite(query), ScoreMode.TOP_SCORES, 1) + .scorer(searcher.getIndexReader().leaves().get(0)); + + assertEquals(1, scorer.iterator().nextDoc()); + assertEquals(2 + 4, scorer.score(), 0); + + assertEquals(3, scorer.iterator().nextDoc()); + assertEquals(2 + 1 + 4, scorer.score(), 0); + + assertEquals(DocIdSetIterator.NO_MORE_DOCS, scorer.iterator().nextDoc()); + + scorer = + searcher + .createWeight(searcher.rewrite(query), ScoreMode.TOP_SCORES, 1) + .scorer(searcher.getIndexReader().leaves().get(0)); + + scorer.setMinCompetitiveScore(2 + 1 + 4); + + assertEquals(3, scorer.iterator().nextDoc()); + assertEquals(2 + 1 + 4, scorer.score(), 0); + + assertEquals(DocIdSetIterator.NO_MORE_DOCS, scorer.iterator().nextDoc()); + } + } + } + + public void testBasicsWithFilteredDisjunctionAndMustNotAndMinShouldMatch() throws Exception { + try (Directory dir = newDirectory()) { + try (IndexWriter w = + new IndexWriter(dir, newIndexWriterConfig().setMergePolicy(newLogMergePolicy()))) { + for (String[] values : + Arrays.asList( + new String[] {"A", "B"}, // 0 + new String[] {"A", "C", "D"}, // 1 + new String[] {}, // 2 + new String[] {"A", "B", "C", "D"}, // 3 + new String[] {"B", "D"}, // 4 + new String[] {"C", "D"} // 5 + )) { + Document doc = new Document(); + for (String value : values) { + doc.add(new StringField("foo", value, Store.NO)); + } + w.addDocument(doc); + } + + w.forceMerge(1); + } + + try (IndexReader reader = DirectoryReader.open(dir)) { + IndexSearcher searcher = newSearcher(reader); + + Query query = + new BooleanQuery.Builder() + .add( + new BoostQuery(new ConstantScoreQuery(new TermQuery(new Term("foo", "A"))), 2), + Occur.SHOULD) + .add(new ConstantScoreQuery(new TermQuery(new Term("foo", "B"))), Occur.SHOULD) + .add(new TermQuery(new Term("foo", "C")), Occur.MUST_NOT) + .add( + new BoostQuery(new ConstantScoreQuery(new TermQuery(new Term("foo", "D"))), 4), + Occur.SHOULD) + .setMinimumNumberShouldMatch(2) + .build(); + + Scorer scorer = + searcher + .createWeight(searcher.rewrite(query), ScoreMode.TOP_SCORES, 1) + .scorer(searcher.getIndexReader().leaves().get(0)); + + assertEquals(0, scorer.iterator().nextDoc()); + assertEquals(2 + 1, scorer.score(), 0); + + assertEquals(4, scorer.iterator().nextDoc()); + assertEquals(1 + 4, scorer.score(), 0); + + assertEquals(DocIdSetIterator.NO_MORE_DOCS, scorer.iterator().nextDoc()); + + scorer = + searcher + .createWeight(searcher.rewrite(query), ScoreMode.TOP_SCORES, 1) + .scorer(searcher.getIndexReader().leaves().get(0)); + + scorer.setMinCompetitiveScore(4); + + assertEquals(4, scorer.iterator().nextDoc()); + assertEquals(1 + 4, scorer.score(), 0); + + assertEquals(DocIdSetIterator.NO_MORE_DOCS, scorer.iterator().nextDoc()); + } + } + } + public void testRandom() throws IOException { Directory dir = newDirectory(); IndexWriter w = new IndexWriter(dir, newIndexWriterConfig());
