fabriziofortino commented on code in PR #1217:
URL: https://github.com/apache/jackrabbit-oak/pull/1217#discussion_r1397515632
##########
oak-search-elastic/src/main/java/org/apache/jackrabbit/oak/plugins/index/elastic/query/async/facets/ElasticStatisticalFacetAsyncProvider.java:
##########
@@ -18,145 +18,178 @@
 import co.elastic.clients.elasticsearch._types.aggregations.Aggregate;
 import co.elastic.clients.elasticsearch._types.aggregations.StringTermsBucket;
+import co.elastic.clients.elasticsearch._types.query_dsl.BoolQuery;
+import co.elastic.clients.elasticsearch._types.query_dsl.Query;
+import co.elastic.clients.elasticsearch.core.SearchRequest;
+import co.elastic.clients.elasticsearch.core.SearchResponse;
 import co.elastic.clients.elasticsearch.core.search.Hit;
+import co.elastic.clients.elasticsearch.core.search.SourceConfig;
+import com.fasterxml.jackson.databind.JsonNode;
 import com.fasterxml.jackson.databind.node.ObjectNode;
+import org.apache.jackrabbit.oak.plugins.index.elastic.ElasticConnection;
+import org.apache.jackrabbit.oak.plugins.index.elastic.ElasticIndexDefinition;
 import org.apache.jackrabbit.oak.plugins.index.elastic.query.ElasticRequestHandler;
 import org.apache.jackrabbit.oak.plugins.index.elastic.query.ElasticResponseHandler;
-import org.apache.jackrabbit.oak.plugins.index.elastic.query.async.ElasticResponseListener;
+import org.apache.jackrabbit.oak.plugins.index.search.FieldNames;
 import org.apache.jackrabbit.oak.plugins.index.search.spi.query.FulltextIndex;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import java.util.ArrayList;
 import java.util.HashMap;
-import java.util.LinkedList;
 import java.util.List;
 import java.util.Map;
-import java.util.Random;
+import java.util.Set;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.CountDownLatch;
 import java.util.concurrent.TimeUnit;
 import java.util.function.Predicate;
 import java.util.stream.Collectors;

 /**
- * An {@link ElasticSecureFacetAsyncProvider} extension that subscribes also on Elastic Aggregation events.
+ * An {@link ElasticFacetProvider} extension that performs random sampling on the result set to compute facets.
  * SearchHit events are sampled and then used to adjust facets coming from Aggregations in order to minimize
- * access checks. This provider could improve facets performance but only when the result set is quite big.
+ * access checks. This provider could improve facets performance especially when the result set is quite big.
  */
-public class ElasticStatisticalFacetAsyncProvider extends ElasticSecureFacetAsyncProvider
-        implements ElasticResponseListener.AggregationListener {
+public class ElasticStatisticalFacetAsyncProvider implements ElasticFacetProvider {

-    private final int sampleSize;
+    private static final Logger LOG = LoggerFactory.getLogger(ElasticStatisticalFacetAsyncProvider.class);
+
+    private final ElasticResponseHandler elasticResponseHandler;
+    private final Predicate<String> isAccessible;
+    private final Set<String> facetFields;
+    private final Map<String, List<FulltextIndex.Facet>> allFacets = new HashMap<>();
+    private final Map<String, Map<String, Integer>> accessibleFacetCounts = new ConcurrentHashMap<>();
+    private Map<String, List<FulltextIndex.Facet>> facets;
+    private final CountDownLatch latch = new CountDownLatch(1);
+    private int sampled;
     private long totalHits;
-    private final Random rGen;
-    private int sampled = 0;
-    private int seen = 0;
-    private long accessibleCount = 0;
+    ElasticStatisticalFacetAsyncProvider(ElasticConnection connection, ElasticIndexDefinition indexDefinition,
+                                         ElasticRequestHandler elasticRequestHandler, ElasticResponseHandler elasticResponseHandler,
+                                         Predicate<String> isAccessible, long randomSeed, int sampleSize) {

-    private final Map<String, List<FulltextIndex.Facet>> facetMap = new HashMap<>();
+        this.elasticResponseHandler = elasticResponseHandler;
+        this.isAccessible = isAccessible;
+        this.facetFields = elasticRequestHandler.facetFields().collect(Collectors.toSet());

-    private final CountDownLatch latch = new CountDownLatch(1);
+        BoolQuery.Builder builder = elasticRequestHandler.baseQueryBuilder();
+        builder.should(sb -> sb.functionScore(fsb ->
+                fsb.functions(f -> f.randomScore(rsb -> rsb.seed("" + randomSeed).field(FieldNames.PATH)))
+        ));

-    ElasticStatisticalFacetAsyncProvider(ElasticRequestHandler elasticRequestHandler,
-                                         ElasticResponseHandler elasticResponseHandler,
-                                         Predicate<String> isAccessible,
-                                         long randomSeed, int sampleSize) {
-        super(elasticRequestHandler, elasticResponseHandler, isAccessible);
-        this.sampleSize = sampleSize;
-        this.rGen = new Random(randomSeed);
-    }
+        SearchRequest searchRequest = SearchRequest.of(srb -> srb.index(indexDefinition.getIndexAlias())
+                .trackTotalHits(thb -> thb.enabled(true))
+                .source(SourceConfig.of(scf -> scf.filter(ff -> ff.includes(FieldNames.PATH).includes(new ArrayList<>(facetFields)))))
+                .query(Query.of(qb -> qb.bool(builder.build())))
+                .aggregations(elasticRequestHandler.aggregations())
+                .size(sampleSize)
+        );

-    @Override
-    public void startData(long totalHits) {
-        this.totalHits = totalHits;
+        LOG.trace("Kicking search query with random sampling {}", searchRequest);
+        CompletableFuture<SearchResponse<ObjectNode>> searchFuture =
+                connection.getAsyncClient().search(searchRequest, ObjectNode.class);
+
+        searchFuture.whenCompleteAsync((searchResponse, throwable) -> {
+            try {
+                if (throwable != null) {
+                    LOG.error("Error while retrieving sample documents", throwable);
+                } else {
+                    List<Hit<ObjectNode>> searchHits = searchResponse.hits().hits();
+                    this.sampled = searchHits != null ? searchHits.size() : 0;
+                    if (sampled > 0) {
+                        this.totalHits = searchResponse.hits().total().value();
+                        processAggregations(searchResponse.aggregations());
+                        searchResponse.hits().hits().forEach(this::processHit);
+                        computeStatisticalFacets();
+                    }
+                }
+            } finally {
+                latch.countDown();
+            }
+        });
     }

     @Override
-    public void on(Hit<ObjectNode> searchHit) {
-        if (totalHits < sampleSize) {
-            super.on(searchHit);
-        } else {
-            if (sampleSize == sampled) {
-                return;
+    public List<FulltextIndex.Facet> getFacets(int numberOfFacets, String columnName) {
+        LOG.trace("Requested facets for {} - Latch count: {}", columnName, latch.getCount());
+        try {
+            boolean completed = latch.await(15, TimeUnit.SECONDS);
+            if (!completed) {
+                throw new IllegalStateException("Timed out while waiting for facets");
             }
-            int r = rGen.nextInt((int) (totalHits - seen)) + 1;
-            seen++;
-
-            if (r <= sampleSize - sampled) {
-                sampled++;
-                final String path = elasticResponseHandler.getPath(searchHit);
-                if (path != null && isAccessible.test(path)) {
-                    accessibleCount++;
+        } catch (InterruptedException e) {
+            Thread.currentThread().interrupt(); // restore interrupt status
+            throw new IllegalStateException("Error while waiting for facets", e);
+        }
+        LOG.trace("Reading facets for {} from {}", columnName, facets);
+        return facets != null ? facets.get(FulltextIndex.parseFacetField(columnName)) : null;
+    }
+
+    private void processHit(Hit<ObjectNode> searchHit) {
+        final String path = elasticResponseHandler.getPath(searchHit);
+        if (path != null && isAccessible.test(path)) {
+            for (String field : facetFields) {
+                JsonNode value = searchHit.source().get(field);
+                if (value != null) {
+                    accessibleFacetCounts.compute(field, (column, facetValues) -> {
+                        if (facetValues == null) {
+                            Map<String, Integer> values = new HashMap<>();
+                            values.put(value.asText(), 1);
+                            return values;
+                        } else {
+                            facetValues.merge(value.asText(), 1, Integer::sum);
+                            return facetValues;
+                        }
+                    });
                 }
             }
         }
     }

-    @Override
-    public void on(Map<String, Aggregate> aggregations) {
+    private void processAggregations(Map<String, Aggregate> aggregations) {
         for (String field : facetFields) {
             List<StringTermsBucket> buckets = aggregations.get(field).sterms().buckets().array();
-            facetMap.put(field, buckets.stream()
+            allFacets.put(field, buckets.stream()
                     .map(b -> new FulltextIndex.Facet(b.key().stringValue(), (int) b.docCount()))
                     .collect(Collectors.toList())
             );
         }
     }

-    @Override
-    public void endData() {
-        if (totalHits < sampleSize) {
-            super.endData();
-        } else {
-            for (String facet: facetMap.keySet()) {
-                facetMap.compute(facet, (s, facets1) -> updateLabelAndValueIfRequired(facets1));
-            }
-            latch.countDown();
-        }
-    }
-
-    @Override
-    public List<FulltextIndex.Facet> getFacets(int numberOfFacets, String columnName) {
-        if (totalHits < sampleSize) {
-            return super.getFacets(numberOfFacets, columnName);
-        } else {
-            LOG.trace("Requested facets for {} - Latch count: {}", columnName, latch.getCount());
-            try {
-                latch.await(15, TimeUnit.SECONDS);
-            } catch (InterruptedException e) {
-                throw new IllegalStateException("Error while waiting for facets", e);
-            }
-            LOG.trace("Reading facets for {} from {}", columnName, facetMap);
-            return facetMap.get(FulltextIndex.parseFacetField(columnName));
-        }
-    }
-
-    private List<FulltextIndex.Facet> updateLabelAndValueIfRequired(List<FulltextIndex.Facet> labelAndValues) {
-        if (accessibleCount < sampleSize) {
-            int numZeros = 0;
-            List<FulltextIndex.Facet> newValues;
-            {
-                List<FulltextIndex.Facet> proportionedLVs = new LinkedList<>();
-                for (FulltextIndex.Facet labelAndValue : labelAndValues) {
-                    long count = labelAndValue.getCount() * accessibleCount / sampleSize;
-                    if (count == 0) {
-                        numZeros++;
+    private void computeStatisticalFacets() {
+        for (String facetKey : allFacets.keySet()) {
+            if (accessibleFacetCounts.containsKey(facetKey)) {
+                Map<String, Integer> accessibleFacet = accessibleFacetCounts.get(facetKey);
+                List<FulltextIndex.Facet> uncheckedFacet = allFacets.get(facetKey);
+                for (FulltextIndex.Facet facet : uncheckedFacet) {
+                    if (accessibleFacet.containsKey(facet.getLabel())) {
+                        double sampleProportion = (double) accessibleFacet.get(facet.getLabel()) / sampled;
+                        // returned count is the minimum between the accessible count and the count computed from the sample
+                        accessibleFacet.put(facet.getLabel(), Math.min(facet.getCount(), (int) (sampleProportion * totalHits)));
                     }
-                    proportionedLVs.add(new FulltextIndex.Facet(labelAndValue.getLabel(), Math.toIntExact(count)));
                 }
-                labelAndValues = proportionedLVs;
             }
-            if (numZeros > 0) {
-                newValues = new LinkedList<>();
-                for (FulltextIndex.Facet lv : labelAndValues) {
-                    if (lv.getCount() > 0) {
-                        newValues.add(lv);
-                    }
-                }
-            } else {
-                newValues = labelAndValues;
-            }
-            return newValues;
-        } else {
-            return labelAndValues;
         }
+        // create Facet objects, order by count (desc) and then by label (asc)
+        facets = accessibleFacetCounts.entrySet()
+                .stream()
+                .collect(Collectors.toMap
+                        (Map.Entry::getKey, x -> x.getValue().entrySet()
+                                .stream()
+                                .map(e -> new FulltextIndex.Facet(e.getKey(), e.getValue()))
+                                .sorted((f1, f2) -> {
+                                    int f1Count = f1.getCount();
+                                    int f2Count = f2.getCount();
+                                    if (f1Count == f2Count) {
+                                        return f1.getLabel().compareTo(f2.getLabel());
+                                    } else return f2Count - f1Count;
+                                })
+                                .collect(Collectors.toList())
+                        )
+                );
+        LOG.trace("Statistical facets {}", facets);

Review Comment:
   no need

-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: dev-unsubscr...@jackrabbit.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org