EBernhardson has uploaded a new change for review. ( 
https://gerrit.wikimedia.org/r/370776 )

Change subject: Abstract routing out of token count router
......................................................................

Abstract routing out of token count router

In preparation for adding a new kind of query router, abstract most of
the routing implementation out of the token count router so that it can
be reused.

Change-Id: If593b7250472b7f5d849ca307a5d34e9126504f5
---
A 
src/main/java/org/wikimedia/search/extra/tokencount/AbstractRouterQueryBuilder.java
M 
src/main/java/org/wikimedia/search/extra/tokencount/TokenCountRouterQueryBuilder.java
M src/test/java/org/wikimedia/search/extra/QueryBuilderTestUtils.java
M src/test/java/org/wikimedia/search/extra/regex/SourceRegexBuilderESTest.java
M 
src/test/java/org/wikimedia/search/extra/tokencount/TokenCountRouterBuilderESTest.java
M 
src/test/java/org/wikimedia/search/extra/tokencount/TokenCountRouterParserTest.java
M 
src/test/java/org/wikimedia/search/extra/tokencount/TokenCountRouterQueryIntegrationTest.java
7 files changed, 342 insertions(+), 225 deletions(-)


  git pull ssh://gerrit.wikimedia.org:29418/search/extra 
refs/changes/76/370776/1

diff --git 
a/src/main/java/org/wikimedia/search/extra/tokencount/AbstractRouterQueryBuilder.java
 
b/src/main/java/org/wikimedia/search/extra/tokencount/AbstractRouterQueryBuilder.java
new file mode 100644
index 0000000..ce5520c
--- /dev/null
+++ 
b/src/main/java/org/wikimedia/search/extra/tokencount/AbstractRouterQueryBuilder.java
@@ -0,0 +1,394 @@
+package org.wikimedia.search.extra.tokencount;
+
+import lombok.AccessLevel;
+import lombok.EqualsAndHashCode;
+import lombok.Getter;
+import lombok.Setter;
+import lombok.experimental.Accessors;
+import org.apache.lucene.search.Query;
+import org.elasticsearch.common.ParseField;
+import org.elasticsearch.common.ParsingException;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.common.xcontent.AbstractObjectParser;
+import org.elasticsearch.common.xcontent.ContextParser;
+import org.elasticsearch.common.xcontent.ObjectParser;
+import org.elasticsearch.common.xcontent.XContentBuilder;
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.index.query.AbstractQueryBuilder;
+import org.elasticsearch.index.query.BoolQueryBuilder;
+import org.elasticsearch.index.query.QueryBuilder;
+import org.elasticsearch.index.query.QueryParseContext;
+import org.elasticsearch.index.query.QueryShardContext;
+import org.wikimedia.search.extra.tokencount.AbstractRouterQueryBuilder.Condition;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Objects;
+import java.util.Optional;
+import java.util.function.Predicate;
+
+/**
+ * Base class for "router" queries: an ordered list of {@link Condition}s,
+ * each pairing an integer comparison with a query, plus a fallback query.
+ * During rewrite the first condition accepted by the subclass-supplied
+ * predicate (see {@link #doRewrite(Predicate)}) replaces this query; when no
+ * condition is accepted the fallback query is used instead. A router cannot
+ * be turned into a Lucene query directly; it must be rewritten first.
+ *
+ * @param <C> type of the conditions held by this router
+ * @param <QB> concrete router type (self type)
+ */
+@Getter
+@Setter
+@Accessors(fluent = true, chain = true)
+public abstract class AbstractRouterQueryBuilder<C extends Condition, QB extends AbstractRouterQueryBuilder<C, QB>>
+        extends AbstractQueryBuilder<QB> {
+    static final ParseField FALLBACK = new ParseField("fallback");
+    static final ParseField CONDITIONS = new ParseField("conditions");
+    static final ParseField QUERY = new ParseField("query");
+
+    /** Routing conditions, evaluated in declaration order. */
+    @Getter(AccessLevel.PROTECTED)
+    private List<C> conditions;
+
+    /** Query selected when no condition is accepted. */
+    private QueryBuilder fallback;
+
+    AbstractRouterQueryBuilder() {
+        this.conditions = new ArrayList<>();
+    }
+
+    /**
+     * Reads the state written by {@link #doWriteTo(StreamOutput)}.
+     *
+     * @param in stream to read from
+     * @param reader reads a single condition of the concrete condition type
+     */
+    AbstractRouterQueryBuilder(StreamInput in, Writeable.Reader<C> reader) throws IOException {
+        super(in);
+        conditions = in.readList(reader);
+        fallback = in.readNamedWriteable(QueryBuilder.class);
+    }
+
+    /**
+     * Writes the conditions followed by the fallback query. Subclasses adding
+     * their own state must call {@code super.doWriteTo(out)} before writing it,
+     * mirroring the stream constructor, which reads the router state first.
+     */
+    @Override
+    protected void doWriteTo(StreamOutput out) throws IOException {
+        out.writeList(conditions);
+        out.writeNamedWriteable(fallback);
+    }
+
+    /**
+     * Selects the query this router should be rewritten to.
+     *
+     * @param condition predicate deciding which condition applies; the first
+     *                  condition, in declaration order, that it accepts wins
+     * @return the query of the selected condition, or the fallback query when
+     *         no condition is accepted; wrapped in a bool query when this
+     *         router has a non-default boost or a query name
+     */
+    QueryBuilder doRewrite(Predicate<C> condition) throws IOException {
+        QueryBuilder selected = conditions.stream()
+                .filter(condition)
+                .findFirst()
+                .map(Condition::query)
+                .orElse(fallback);
+
+        if (boost() != DEFAULT_BOOST || queryName() != null) {
+            // AbstractQueryBuilder#rewrite copies a non-default boost/name onto
+            // the rewritten query. Wrapping the selected query in a fresh bool
+            // query keeps that copy from overriding the boost/name of the
+            // selected query itself.
+            return new BoolQueryBuilder().must(selected);
+        }
+        return selected;
+    }
+
+    @Override
+    protected boolean doEquals(QB other) {
+        // Private fields cannot be accessed through the type variable QB,
+        // so view the other instance through this class's own type.
+        AbstractRouterQueryBuilder<C, QB> that = other;
+        return Objects.equals(fallback, that.fallback) &&
+                Objects.equals(conditions, that.conditions);
+    }
+
+    @Override
+    protected int doHashCode() {
+        return Objects.hash(fallback, conditions);
+    }
+
+    /**
+     * Always throws: a router must be rewritten into the selected query before
+     * a Lucene query is built.
+     */
+    @Override
+    protected Query doToQuery(QueryShardContext queryShardContext) throws IOException {
+        throw new UnsupportedOperationException("This query must be rewritten.");
+    }
+
+    /**
+     * Hook for subclasses to write their own fields into this query's XContent
+     * object, after the shared router fields. Writes nothing by default.
+     */
+    protected void addXContent(XContentBuilder builder, Params params) throws IOException {
+    }
+
+    @Override
+    protected void doXContent(XContentBuilder builder, Params params) throws IOException {
+        builder.startObject(getWriteableName());
+        if (fallback() != null) {
+            builder.field(FALLBACK.getPreferredName(), fallback());
+        }
+        if (!conditions().isEmpty()) {
+            builder.startArray(CONDITIONS.getPreferredName());
+            for (C c : conditions()) {
+                c.doXContent(builder, params);
+            }
+            builder.endArray();
+        }
+
+        addXContent(builder, params);
+        this.printBoostAndQueryName(builder);
+        builder.endObject();
+    }
+
+    /**
+     * Parses and validates a single condition object.
+     *
+     * @throws ParsingException if the condition lacks a query or a predicate
+     */
+    static <C extends Condition, CPS extends ConditionParserState<C>> C parseCondition(
+            ObjectParser<CPS, QueryParseContext> condParser, XContentParser parser, QueryParseContext parseContext
+    ) throws IOException {
+        CPS state = condParser.parse(parser, parseContext);
+        String error = state.checkValid();
+        if (error != null) {
+            throw new ParsingException(parser.getTokenLocation(), error);
+        }
+        return state.condition();
+    }
+
+    /**
+     * Parses a router query and validates the fields shared by all routers.
+     *
+     * @return the parsed builder, always present
+     * @throws ParsingException if the parser rejects the input, or if no
+     *                          conditions or no fallback query are defined
+     */
+    static <QB extends AbstractRouterQueryBuilder<?, QB>> Optional<QB> fromXContent(
+            AbstractObjectParser<QB, QueryParseContext> objectParser, QueryParseContext parseContext) throws IOException {
+        XContentParser parser = parseContext.parser();
+        final QB builder;
+        try {
+            builder = objectParser.parse(parser, parseContext);
+        } catch (IllegalArgumentException iae) {
+            // keep the original exception as the cause so its stack trace is not lost
+            throw new ParsingException(parser.getTokenLocation(), iae.getMessage(), iae);
+        }
+
+        if (builder.conditions().isEmpty()) {
+            throw new ParsingException(parser.getTokenLocation(), "No conditions defined");
+        }
+
+        if (builder.fallback() == null) {
+            throw new ParsingException(parser.getTokenLocation(), "No fallback query defined");
+        }
+
+        return Optional.of(builder);
+    }
+
+    /**
+     * A single routing rule: its query is selected when
+     * {@code lhs <definition> value} holds for the value being routed on.
+     */
+    @Getter
+    @Accessors(fluent = true, chain = true)
+    @EqualsAndHashCode
+    public static class Condition implements Writeable {
+        private final ConditionDefinition definition;
+        private final int value;
+        private final QueryBuilder query;
+
+        Condition(StreamInput in) throws IOException {
+            definition = ConditionDefinition.readFrom(in);
+            value = in.readVInt();
+            query = in.readNamedWriteable(QueryBuilder.class);
+        }
+
+        Condition(ConditionDefinition definition, int value, QueryBuilder query) {
+            this.definition = Objects.requireNonNull(definition);
+            this.value = value;
+            this.query = Objects.requireNonNull(query);
+        }
+
+        @Override
+        public void writeTo(StreamOutput out) throws IOException {
+            definition.writeTo(out);
+            out.writeVInt(value);
+            out.writeNamedWriteable(query);
+        }
+
+        /**
+         * @param lhs left-hand operand of the comparison (e.g. a token count)
+         * @return whether {@code lhs <definition> value} holds
+         */
+        public boolean test(int lhs) {
+            return definition.test(lhs, value);
+        }
+
+        /** Hook for subclasses to write extra fields into the condition object. Writes nothing by default. */
+        void addXContent(XContentBuilder builder, Params params) throws IOException {
+        }
+
+        void doXContent(XContentBuilder builder, Params params) throws IOException {
+            builder.startObject();
+            builder.field(definition.parseField.getPreferredName(), value);
+            builder.field(QUERY.getPreferredName(), query);
+            addXContent(builder, params);
+            builder.endObject();
+        }
+    }
+
+    /**
+     * Declares the {@code conditions} and {@code fallback} fields on a router parser.
+     *
+     * @param parser the router's parser
+     * @param objectParser parses one element of the {@code conditions} array
+     */
+    static <C extends Condition, QB extends AbstractRouterQueryBuilder<C, QB>>
+    void declareRouterFields(AbstractObjectParser<QB, QueryParseContext> parser,
+                             ContextParser<QueryParseContext, C> objectParser) {
+        parser.declareObjectArray(QB::conditions, objectParser, CONDITIONS);
+        parser.declareObject(QB::fallback,
+                (p, ctx) -> ctx.parseInnerQueryBuilder()
+                        .orElseThrow(() -> new ParsingException(p.getTokenLocation(), "No fallback query defined")),
+                FALLBACK);
+    }
+
+    /**
+     * Declares the fields of a condition object: one integer field per
+     * {@link ConditionDefinition} (e.g. {@code gt}) plus the {@code query} object.
+     */
+    static <CPS extends ConditionParserState<?>>
+    void declareConditionFields(AbstractObjectParser<CPS, QueryParseContext> parser) {
+        for (ConditionDefinition def : ConditionDefinition.values()) {
+            // gt: int, addPredicate will fail if a predicate has already been set
+            parser.declareInt((cps, value) -> cps.addPredicate(def, value), def.parseField);
+        }
+        // query: { }
+        parser.declareObject(CPS::setQuery,
+                (p, ctx) -> ctx.parseInnerQueryBuilder()
+                        .orElseThrow(() -> new ParsingException(p.getTokenLocation(), "No query defined for condition")),
+                QUERY);
+    }
+
+    /** Creates a condition of a concrete type from its parsed parts. */
+    @FunctionalInterface
+    interface ConditionProvider<C extends Condition> {
+        C create(ConditionDefinition def, int value, QueryBuilder query);
+    }
+
+    /** Mutable state accumulated while parsing a single condition object. */
+    static class ConditionParserState<C extends Condition> {
+        private ConditionProvider<C> provider;
+        private ConditionDefinition definition;
+        private int value;
+        protected QueryBuilder query;
+
+        ConditionParserState(ConditionProvider<C> provider) {
+            this.provider = provider;
+        }
+
+        /**
+         * Replaces the provider. A subclass cannot hand the constructor a provider
+         * that refers to its own instance state, since that state does not exist
+         * until {@code super(...)} returns, so it sets the provider here instead.
+         */
+        void provider(ConditionProvider<C> provider) {
+            this.provider = provider;
+        }
+
+        void addPredicate(ConditionDefinition def, int value) {
+            if (this.definition != null) {
+                throw new IllegalArgumentException("Cannot set extra predicate [" + def.parseField + "] " +
+                        "on condition: [" + this.definition.parseField + "] already set");
+            }
+            this.definition = def;
+            this.value = value;
+        }
+
+        public void setQuery(QueryBuilder query) {
+            this.query = query;
+        }
+
+        C condition() {
+            Objects.requireNonNull(provider);
+            return provider.create(definition, value, query);
+        }
+
+        /** @return an error message if the state is incomplete, or {@code null} if it is valid */
+        String checkValid() {
+            if (query == null) {
+                return "Missing field [query] in condition";
+            }
+            if (definition == null) {
+                return "Missing condition predicate in condition";
+            }
+            return null;
+        }
+    }
+
+    /** Predicate over two primitive ints. */
+    @FunctionalInterface
+    public interface BiIntPredicate {
+        boolean test(int a, int b);
+    }
+
+    /**
+     * Comparison operators usable in a condition; each constant's name is also
+     * its field name when parsing. Constants are serialized by ordinal, so never
+     * reorder or remove them; only append new ones.
+     */
+    public enum ConditionDefinition implements BiIntPredicate, Writeable {
+        eq ((a,b) -> a == b),
+        neq ((a,b) -> a != b),
+        lte ((a,b) -> a <= b),
+        lt ((a,b) -> a < b),
+        gte ((a,b) -> a >= b),
+        gt ((a,b) -> a > b);
+
+        final ParseField parseField;
+        final BiIntPredicate predicate;
+
+        ConditionDefinition(BiIntPredicate predicate) {
+            this.predicate = predicate;
+            this.parseField = new ParseField(name());
+        }
+
+        @Override
+        public boolean test(int a, int b) {
+            return predicate.test(a, b);
+        }
+
+        @Override
+        public void writeTo(StreamOutput out) throws IOException {
+            out.writeVInt(ordinal());
+        }
+
+        public static ConditionDefinition readFrom(StreamInput in) throws IOException {
+            int ord = in.readVInt();
+            ConditionDefinition[] all = values();
+            if (ord < 0 || ord >= all.length) {
+                throw new IOException("Unknown ConditionDefinition ordinal [" + ord + "]");
+            }
+            return all[ord];
+        }
+    }
+}
diff --git 
a/src/main/java/org/wikimedia/search/extra/tokencount/TokenCountRouterQueryBuilder.java
 
b/src/main/java/org/wikimedia/search/extra/tokencount/TokenCountRouterQueryBuilder.java
index 60d2f3d..0a2a12d 100644
--- 
a/src/main/java/org/wikimedia/search/extra/tokencount/TokenCountRouterQueryBuilder.java
+++ 
b/src/main/java/org/wikimedia/search/extra/tokencount/TokenCountRouterQueryBuilder.java
@@ -1,36 +1,28 @@
 package org.wikimedia.search.extra.tokencount;
 
-import lombok.AccessLevel;
-import lombok.EqualsAndHashCode;
+import com.google.common.annotations.VisibleForTesting;
 import lombok.Getter;
 import lombok.Setter;
 import lombok.experimental.Accessors;
 import org.apache.lucene.analysis.Analyzer;
 import org.apache.lucene.analysis.TokenStream;
 import org.apache.lucene.analysis.tokenattributes.PositionIncrementAttribute;
-import org.apache.lucene.search.Query;
 import org.elasticsearch.common.ParseField;
 import org.elasticsearch.common.ParsingException;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
-import org.elasticsearch.common.io.stream.Writeable;
 import org.elasticsearch.common.xcontent.ObjectParser;
 import org.elasticsearch.common.xcontent.XContentBuilder;
 import org.elasticsearch.common.xcontent.XContentParser;
 import org.elasticsearch.index.mapper.MappedFieldType;
-import org.elasticsearch.index.query.AbstractQueryBuilder;
-import org.elasticsearch.index.query.BoolQueryBuilder;
 import org.elasticsearch.index.query.QueryBuilder;
 import org.elasticsearch.index.query.QueryParseContext;
 import org.elasticsearch.index.query.QueryRewriteContext;
-import org.elasticsearch.index.query.QueryShardContext;
+import 
org.wikimedia.search.extra.tokencount.AbstractRouterQueryBuilder.Condition;
 
 import java.io.IOException;
-import java.util.ArrayList;
-import java.util.List;
 import java.util.Objects;
 import java.util.Optional;
-import java.util.function.IntPredicate;
 import java.util.stream.Stream;
 
 /**
@@ -39,44 +31,27 @@
 @Getter
 @Setter
 @Accessors(fluent = true, chain = true)
-public class TokenCountRouterQueryBuilder extends 
AbstractQueryBuilder<TokenCountRouterQueryBuilder> {
+public class TokenCountRouterQueryBuilder extends 
AbstractRouterQueryBuilder<Condition, TokenCountRouterQueryBuilder> {
     public static final ParseField NAME = new ParseField("token_count_router");
     static final ParseField TEXT = new ParseField("text");
     static final ParseField FIELD = new ParseField("field");
     static final ParseField ANALYZER = new ParseField("analyzer");
     static final ParseField DISCOUNT_OVERLAPS = new 
ParseField("discount_overlaps");
-    static final ParseField CONDITIONS = new ParseField("conditions");
-    static final ParseField FALLBACK = new ParseField("fallback");
-    static final ParseField QUERY = new ParseField("query");
     static final boolean DEFAULT_DISCOUNT_OVERLAPS = true;
 
     private final static ObjectParser<TokenCountRouterQueryBuilder, 
QueryParseContext> PARSER;
-    private final static ObjectParser<ConditionParserState, QueryParseContext> 
COND_PARSER;
+    private final static ObjectParser<ConditionParserState<Condition>, 
QueryParseContext> COND_PARSER;
 
     static {
-
-        COND_PARSER = new ObjectParser<>("condition", 
ConditionParserState::new);
-        for (ConditionDefinition def : ConditionDefinition.values()) {
-            // gt: int, addPredicate will fail if a predicate has already been 
set
-            COND_PARSER.declareInt((cps, value) -> cps.addPredicate(def, 
value), def.parseField);
-        }
-        // query: { }
-        COND_PARSER.declareObject(ConditionParserState::setQuery,
-                (p, ctx) -> ctx.parseInnerQueryBuilder()
-                        .orElseThrow(() -> new 
ParsingException(p.getTokenLocation(), "No query defined for condition")),
-                QUERY);
+        COND_PARSER = new ObjectParser<>("condition", () -> new 
ConditionParserState<>(Condition::new));
+        declareConditionFields(COND_PARSER);
 
         PARSER = new ObjectParser<>(NAME.getPreferredName(), 
TokenCountRouterQueryBuilder::new);
         PARSER.declareString(TokenCountRouterQueryBuilder::text, TEXT);
         PARSER.declareString(TokenCountRouterQueryBuilder::field, FIELD);
         PARSER.declareString(TokenCountRouterQueryBuilder::analyzer, ANALYZER);
         PARSER.declareBoolean(TokenCountRouterQueryBuilder::discountOverlaps, 
DISCOUNT_OVERLAPS);
-        PARSER.declareObjectArray(TokenCountRouterQueryBuilder::conditions,
-                TokenCountRouterQueryBuilder::parseCondition, CONDITIONS);
-        PARSER.declareObject(TokenCountRouterQueryBuilder::fallback,
-                (p, ctx) -> ctx.parseInnerQueryBuilder()
-                        .orElseThrow(() -> new 
ParsingException(p.getTokenLocation(), "No fallback query defined")),
-                FALLBACK);
+        declareRouterFields(PARSER, (p, pc) -> parseCondition(COND_PARSER, p, 
pc));
         declareStandardFields(PARSER);
     }
 
@@ -85,45 +60,26 @@
     private String field;
     private boolean discountOverlaps = DEFAULT_DISCOUNT_OVERLAPS;
     private String text;
-    private QueryBuilder fallback;
-
-    @Getter(AccessLevel.PRIVATE)
-    private List<Condition> conditions;
 
     public TokenCountRouterQueryBuilder() {
-        this.conditions = new ArrayList<>();
+        super();
     }
 
     public TokenCountRouterQueryBuilder(StreamInput in) throws IOException {
-        super(in);
-
+        super(in, Condition::new);
         analyzer = in.readOptionalString();
         field = in.readOptionalString();
         discountOverlaps = in.readBoolean();
         text = in.readString();
-        conditions = in.readList(Condition::new);
-        fallback = in.readNamedWriteable(QueryBuilder.class);
-    }
-
-    private static Condition parseCondition(XContentParser parser, 
QueryParseContext parseContext) throws IOException {
-        ConditionParserState state = COND_PARSER.parse(parser, parseContext);
-        if (state.query == null) {
-            throw new ParsingException(parser.getTokenLocation(), "Missing 
field [query] in condition");
-        }
-        if (state.definition == null) {
-            throw new ParsingException(parser.getTokenLocation(), "Missing 
condition predicate in condition");
-        }
-        return state.condition();
     }
 
     @Override
     protected void doWriteTo(StreamOutput out) throws IOException {
+        super.doWriteTo(out);
         out.writeOptionalString(analyzer);
         out.writeOptionalString(field);
         out.writeBoolean(discountOverlaps);
         out.writeString(text);
-        out.writeList(conditions);
-        out.writeNamedWriteable(fallback);
     }
 
     @Override
@@ -131,14 +87,8 @@
         return NAME.getPreferredName();
     }
 
-    public TokenCountRouterQueryBuilder condition(ConditionDefinition 
predicate, int value, QueryBuilder query) {
-        conditions.add(new Condition(predicate, value, query));
-        return this;
-    }
-
     @Override
-    protected void doXContent(XContentBuilder builder, Params params) throws 
IOException {
-        builder.startObject(NAME.getPreferredName());
+    protected void addXContent(XContentBuilder builder, Params params) throws 
IOException {
         if (analyzer != null) {
             builder.field(ANALYZER.getPreferredName(), analyzer);
         }
@@ -151,53 +101,19 @@
         if (text != null) {
             builder.field(TEXT.getPreferredName(), text);
         }
-        if (fallback != null) {
-            builder.field(FALLBACK.getPreferredName(), fallback);
-        }
-        if (!conditions.isEmpty()) {
-            builder.startArray(CONDITIONS.getPreferredName());
-            for (Condition c : conditions) {
-                builder.startObject();
-                builder.field(c.definition.parseField.getPreferredName(), 
c.value);
-                builder.field(QUERY.getPreferredName(), c.query);
-                builder.endObject();
-            }
-            builder.endArray();
-        }
-        this.printBoostAndQueryName(builder);
-        builder.endObject();
     }
 
     public static Optional<TokenCountRouterQueryBuilder> 
fromXContent(QueryParseContext parseContext) throws IOException {
+        final Optional<TokenCountRouterQueryBuilder> builder = 
AbstractRouterQueryBuilder.fromXContent(PARSER, parseContext);
+
         XContentParser parser = parseContext.parser();
-        final TokenCountRouterQueryBuilder builder;
-        try {
-            builder = PARSER.parse(parser, parseContext);
-        } catch (IllegalArgumentException iae) {
-            throw new ParsingException(parser.getTokenLocation(), 
iae.getMessage());
-        }
+        builder.filter(b -> b.text != null)
+                .orElseThrow(() -> new 
ParsingException(parser.getTokenLocation(), "No text provided"));
 
-        if (builder.conditions.isEmpty()) {
-            throw new ParsingException(parser.getTokenLocation(), "No 
conditions defined");
-        }
+        builder.filter(b -> b.analyzer != null || b.field != null)
+                .orElseThrow(() -> new 
ParsingException(parser.getTokenLocation(), "Missing field or analyzer 
definition"));
 
-        if (builder.text() == null) {
-            throw new ParsingException(parser.getTokenLocation(), "No text 
provided");
-        }
-
-        if (builder.fallback() == null) {
-            throw new ParsingException(parser.getTokenLocation(), "No fallback 
query defined");
-        }
-
-        if (builder.analyzer() == null && builder.field() == null) {
-            throw new ParsingException(parser.getTokenLocation(), "Missing 
field or analyzer definition");
-        }
-        return Optional.of(builder);
-    }
-
-    @Override
-    protected Query doToQuery(QueryShardContext queryShardContext) throws 
IOException {
-        throw new UnsupportedOperationException("This query must be 
rewritten.");
+        return builder;
     }
 
     @Override
@@ -224,21 +140,9 @@
         if (text == null) {
             throw new IllegalArgumentException("text cannot be null");
         }
-        final int count = countToken(luceneAnalyzer, text, discountOverlaps);
-        QueryBuilder qb = conditions.stream()
-                .filter(c -> c.test(count))
-                .findFirst()
-                .map(Condition::query)
-                .orElse(fallback);
 
-        if (boost() != DEFAULT_BOOST || queryName() != null) {
-            // AbstractQueryBuilder#rewrite will copy non default boost/name
-            // to the rewritten query, we pass a fresh BoolQuery so we don't
-            // override the one on the rewritten query here
-            // Is this really useful?
-            return new BoolQueryBuilder().must(qb);
-        }
-        return qb;
+        final int count = countToken(luceneAnalyzer, text, discountOverlaps);
+        return super.doRewrite((c) -> c.test(count));
     }
 
     static int countToken(Analyzer analyzer, String text, boolean 
discountOverlaps) throws IOException {
@@ -255,115 +159,29 @@
         }
     }
 
+    @VisibleForTesting
     Stream<Condition> conditionStream() {
-        return conditions.stream();
+        return conditions().stream();
     }
 
     @Override
     protected boolean doEquals(TokenCountRouterQueryBuilder other) {
-        return Objects.equals(text, other.text) &&
+        return super.doEquals(other) &&
+                Objects.equals(text, other.text) &&
                 Objects.equals(field, other.field) &&
                 Objects.equals(analyzer, other.analyzer) &&
-                Objects.equals(discountOverlaps, other.discountOverlaps) &&
-                Objects.equals(fallback, other.fallback) &&
-                Objects.equals(conditions, other.conditions);
+                Objects.equals(discountOverlaps, other.discountOverlaps);
     }
 
     @Override
     protected int doHashCode() {
-        return Objects.hash(text, field, analyzer, discountOverlaps, fallback, 
conditions);
+        // TODO: Is calling super.doHashCode the right thing here?
+        return Objects.hash(text, field, analyzer, discountOverlaps, 
super.doHashCode());
     }
 
-    @EqualsAndHashCode
-    @Getter
-    static class Condition implements Writeable, IntPredicate {
-        private final ConditionDefinition definition;
-        private final int value;
-        private final QueryBuilder query;
-
-        Condition(StreamInput in) throws IOException {
-            this.definition = ConditionDefinition.readFrom(in);
-            value = in.readVInt();
-            query = in.readNamedWriteable(QueryBuilder.class);
-        }
-
-        Condition(ConditionDefinition defition, int value, QueryBuilder query) 
{
-            this.definition = Objects.requireNonNull(defition);
-            this.value = value;
-            this.query = Objects.requireNonNull(query);
-        }
-
-        @Override
-        public void writeTo(StreamOutput out) throws IOException {
-            definition.writeTo(out);
-            out.writeVInt(value);
-            out.writeNamedWriteable(query);
-        }
-
-        @Override
-        public boolean test(int tokenCount) {
-            return definition.test(tokenCount, value);
-        }
-    }
-
-    @FunctionalInterface
-    public interface BiIntPredicate {
-        boolean test(int a, int b);
-    }
-
-    public enum ConditionDefinition implements BiIntPredicate, Writeable {
-        eq ((a,b) -> a == b),
-        neq ((a,b) -> a != b),
-        lte ((a,b) -> a <= b),
-        lt ((a,b) -> a < b),
-        gte ((a,b) -> a >= b),
-        gt ((a,b) -> a > b);
-
-        private final ParseField parseField;
-        private final BiIntPredicate predicate;
-
-        ConditionDefinition(BiIntPredicate predicate) {
-            this.predicate = predicate;
-            this.parseField = new ParseField(name());
-        }
-
-        @Override
-        public boolean test(int a, int b) {
-            return predicate.test(a, b);
-        }
-
-        @Override
-        public void writeTo(StreamOutput out) throws IOException {
-            out.writeVInt(ordinal());
-        }
-
-        public static ConditionDefinition readFrom(StreamInput in) throws 
IOException {
-            int ord = in.readVInt();
-            if (ord < 0 || ord >= values().length) {
-                throw new IOException("Unknown ConditionDefinition ordinal [" 
+ ord + "]");
-            }
-            return values()[ord];
-        }
-    }
-
-    private static class ConditionParserState {
-        private ConditionDefinition definition;
-        private int value;
-        private QueryBuilder query;
-
-        public void addPredicate(ConditionDefinition def, int value) {
-            if (this.definition != null) {
-                throw new IllegalArgumentException("Cannot set extra predicate 
[" + def.parseField + "] " +
-                        "on condition: [" + this.definition.parseField + "] 
already set");
-            }
-            this.definition = def;
-            this.value = value;
-        }
-        public void setQuery(QueryBuilder query) {
-            this.query = query;
-        }
-        public Condition condition() {
-            return new Condition(definition, value, query);
-        }
+    @VisibleForTesting
+    public TokenCountRouterQueryBuilder condition(ConditionDefinition def, int 
value, QueryBuilder qb) {
+        conditions().add(new Condition(def, value, qb));
+        return this;
     }
 }
diff --git 
a/src/test/java/org/wikimedia/search/extra/QueryBuilderTestUtils.java 
b/src/test/java/org/wikimedia/search/extra/QueryBuilderTestUtils.java
index a245944..d60a3d5 100644
--- a/src/test/java/org/wikimedia/search/extra/QueryBuilderTestUtils.java
+++ b/src/test/java/org/wikimedia/search/extra/QueryBuilderTestUtils.java
@@ -10,6 +10,7 @@
 import org.elasticsearch.search.SearchModule;
 
 import java.io.IOException;
+import java.nio.file.Paths;
 import java.util.Collections;
 import java.util.List;
 import java.util.Optional;
@@ -24,7 +25,7 @@
     private final NamedXContentRegistry xContentRegistry;
 
     private QueryBuilderTestUtils() {
-        SearchModule module = new SearchModule(Settings.EMPTY, false, 
Collections.singletonList(new ExtraPlugin()));
+        SearchModule module = new SearchModule(Settings.EMPTY, false, 
Collections.singletonList(new ExtraPlugin(Settings.EMPTY)));
         xContentRegistry = new 
NamedXContentRegistry(module.getNamedXContents());
     }
 
diff --git 
a/src/test/java/org/wikimedia/search/extra/regex/SourceRegexBuilderESTest.java 
b/src/test/java/org/wikimedia/search/extra/regex/SourceRegexBuilderESTest.java
index f658f3f..4b84523 100644
--- 
a/src/test/java/org/wikimedia/search/extra/regex/SourceRegexBuilderESTest.java
+++ 
b/src/test/java/org/wikimedia/search/extra/regex/SourceRegexBuilderESTest.java
@@ -23,7 +23,7 @@
 
 import static org.hamcrest.CoreMatchers.containsString;
 import static org.hamcrest.CoreMatchers.instanceOf;
-import static 
org.wikimedia.search.extra.tokencount.TokenCountRouterQueryBuilder.ConditionDefinition.gte;
+import static 
org.wikimedia.search.extra.tokencount.AbstractRouterQueryBuilder.ConditionDefinition.gte;
 
 public class SourceRegexBuilderESTest extends 
AbstractQueryTestCase<SourceRegexQueryBuilder> {
     protected Collection<Class<? extends Plugin>> getPlugins() {
diff --git 
a/src/test/java/org/wikimedia/search/extra/tokencount/TokenCountRouterBuilderESTest.java
 
b/src/test/java/org/wikimedia/search/extra/tokencount/TokenCountRouterBuilderESTest.java
index 93ab047..9cede5c 100644
--- 
a/src/test/java/org/wikimedia/search/extra/tokencount/TokenCountRouterBuilderESTest.java
+++ 
b/src/test/java/org/wikimedia/search/extra/tokencount/TokenCountRouterBuilderESTest.java
@@ -22,7 +22,7 @@
 import org.elasticsearch.search.internal.SearchContext;
 import org.elasticsearch.test.AbstractQueryTestCase;
 import org.wikimedia.search.extra.ExtraPlugin;
-import 
org.wikimedia.search.extra.tokencount.TokenCountRouterQueryBuilder.Condition;
+import 
org.wikimedia.search.extra.tokencount.AbstractRouterQueryBuilder.Condition;
 
 import java.io.IOException;
 import java.util.Collection;
@@ -32,8 +32,8 @@
 import static org.hamcrest.CoreMatchers.containsString;
 import static org.hamcrest.CoreMatchers.equalTo;
 import static org.hamcrest.CoreMatchers.instanceOf;
-import static 
org.wikimedia.search.extra.tokencount.TokenCountRouterQueryBuilder.ConditionDefinition.gt;
-import static 
org.wikimedia.search.extra.tokencount.TokenCountRouterQueryBuilder.ConditionDefinition.gte;
+import static 
org.wikimedia.search.extra.tokencount.AbstractRouterQueryBuilder.ConditionDefinition.gt;
+import static 
org.wikimedia.search.extra.tokencount.AbstractRouterQueryBuilder.ConditionDefinition.gte;
 
 public class TokenCountRouterBuilderESTest extends 
AbstractQueryTestCase<TokenCountRouterQueryBuilder> {
     protected Collection<Class<? extends Plugin>> getPlugins() {
@@ -62,7 +62,7 @@
         }
 
         for(int i = randomIntBetween(1,10); i > 0; i--) {
-            TokenCountRouterQueryBuilder.ConditionDefinition cond = 
randomFrom(TokenCountRouterQueryBuilder.ConditionDefinition.values());
+            AbstractRouterQueryBuilder.ConditionDefinition cond = 
randomFrom(AbstractRouterQueryBuilder.ConditionDefinition.values());
             int value = randomInt(10);
             builder.condition(cond, value, new TermQueryBuilder(cond.name(), 
String.valueOf(value)));
         }
@@ -199,9 +199,8 @@
         builder.analyzer(randomAnalyzer());
         QueryBuilder toRewrite = new TermQueryBuilder("fallback", "fallback");
         builder.fallback(new WrapperQueryBuilder(toRewrite.toString()));
-        int nbCond = randomInt(10);
         for(int i = randomIntBetween(1,10); i > 0; i--) {
-            TokenCountRouterQueryBuilder.ConditionDefinition cond = 
randomFrom(TokenCountRouterQueryBuilder.ConditionDefinition.values());
+            AbstractRouterQueryBuilder.ConditionDefinition cond = 
randomFrom(AbstractRouterQueryBuilder.ConditionDefinition.values());
             int value = randomInt(10);
             builder.condition(cond, value, new 
WrapperQueryBuilder(toRewrite.toString()));
         }
diff --git 
a/src/test/java/org/wikimedia/search/extra/tokencount/TokenCountRouterParserTest.java
 
b/src/test/java/org/wikimedia/search/extra/tokencount/TokenCountRouterParserTest.java
index 05fecae..68dc63b 100644
--- 
a/src/test/java/org/wikimedia/search/extra/tokencount/TokenCountRouterParserTest.java
+++ 
b/src/test/java/org/wikimedia/search/extra/tokencount/TokenCountRouterParserTest.java
@@ -7,13 +7,13 @@
 import org.elasticsearch.index.query.QueryBuilders;
 import org.junit.Test;
 import org.wikimedia.search.extra.QueryBuilderTestUtils;
-import 
org.wikimedia.search.extra.tokencount.TokenCountRouterQueryBuilder.Condition;
+import 
org.wikimedia.search.extra.tokencount.AbstractRouterQueryBuilder.Condition;
 
 import java.io.IOException;
 import java.util.Optional;
 
 import static org.hamcrest.CoreMatchers.instanceOf;
-import static 
org.wikimedia.search.extra.tokencount.TokenCountRouterQueryBuilder.ConditionDefinition.gte;
+import static 
org.wikimedia.search.extra.tokencount.AbstractRouterQueryBuilder.ConditionDefinition.gte;
 
 public class TokenCountRouterParserTest extends LuceneTestCase {
     @Test
diff --git 
a/src/test/java/org/wikimedia/search/extra/tokencount/TokenCountRouterQueryIntegrationTest.java
 
b/src/test/java/org/wikimedia/search/extra/tokencount/TokenCountRouterQueryIntegrationTest.java
index 297fc24..2b226fd 100644
--- 
a/src/test/java/org/wikimedia/search/extra/tokencount/TokenCountRouterQueryIntegrationTest.java
+++ 
b/src/test/java/org/wikimedia/search/extra/tokencount/TokenCountRouterQueryIntegrationTest.java
@@ -48,8 +48,8 @@
         init();
         TokenCountRouterQueryBuilder builder = new 
TokenCountRouterQueryBuilder();
         builder.field("content")
-                
.condition(TokenCountRouterQueryBuilder.ConditionDefinition.gt, 4, 
QueryBuilders.termQuery("content", "absent"))
-                
.condition(TokenCountRouterQueryBuilder.ConditionDefinition.gte, 2, 
QueryBuilders.termQuery("content", "haste"))
+                .condition(AbstractRouterQueryBuilder.ConditionDefinition.gt, 
4, QueryBuilders.termQuery("content", "absent"))
+                .condition(AbstractRouterQueryBuilder.ConditionDefinition.gte, 
2, QueryBuilders.termQuery("content", "haste"))
                 .fallback(QueryBuilders.termQuery("content", "strength"));
         SearchResponse sr;
         builder.text("one and two and three");

-- 
To view, visit https://gerrit.wikimedia.org/r/370776
To unsubscribe, visit https://gerrit.wikimedia.org/r/settings

Gerrit-MessageType: newchange
Gerrit-Change-Id: If593b7250472b7f5d849ca307a5d34e9126504f5
Gerrit-PatchSet: 1
Gerrit-Project: search/extra
Gerrit-Branch: master
Gerrit-Owner: EBernhardson <ebernhard...@wikimedia.org>

_______________________________________________
MediaWiki-commits mailing list
MediaWiki-commits@lists.wikimedia.org
https://lists.wikimedia.org/mailman/listinfo/mediawiki-commits

Reply via email to