Github user kchilton2 commented on a diff in the pull request:

    https://github.com/apache/incubator-rya/pull/254#discussion_r159335982
  
    --- Diff: dao/mongodb.rya/src/main/java/org/apache/rya/mongodb/aggregation/AggregationPipelineQueryNode.java ---
    @@ -0,0 +1,882 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one
    + * or more contributor license agreements.  See the NOTICE file
    + * distributed with this work for additional information
    + * regarding copyright ownership.  The ASF licenses this file
    + * to you under the Apache License, Version 2.0 (the
    + * "License"); you may not use this file except in compliance
    + * with the License.  You may obtain a copy of the License at
    + *
    + *   http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing,
    + * software distributed under the License is distributed on an
    + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
    + * KIND, either express or implied.  See the License for the
    + * specific language governing permissions and limitations
    + * under the License.
    + */
    +package org.apache.rya.mongodb.aggregation;
    +
    +import static org.apache.rya.mongodb.dao.SimpleMongoDBStorageStrategy.CONTEXT;
    +import static org.apache.rya.mongodb.dao.SimpleMongoDBStorageStrategy.DOCUMENT_VISIBILITY;
    +import static org.apache.rya.mongodb.dao.SimpleMongoDBStorageStrategy.OBJECT;
    +import static org.apache.rya.mongodb.dao.SimpleMongoDBStorageStrategy.OBJECT_HASH;
    +import static org.apache.rya.mongodb.dao.SimpleMongoDBStorageStrategy.OBJECT_TYPE;
    +import static org.apache.rya.mongodb.dao.SimpleMongoDBStorageStrategy.PREDICATE;
    +import static org.apache.rya.mongodb.dao.SimpleMongoDBStorageStrategy.PREDICATE_HASH;
    +import static org.apache.rya.mongodb.dao.SimpleMongoDBStorageStrategy.STATEMENT_METADATA;
    +import static org.apache.rya.mongodb.dao.SimpleMongoDBStorageStrategy.SUBJECT;
    +import static org.apache.rya.mongodb.dao.SimpleMongoDBStorageStrategy.SUBJECT_HASH;
    +import static org.apache.rya.mongodb.dao.SimpleMongoDBStorageStrategy.TIMESTAMP;
    +
    +import java.util.Arrays;
    +import java.util.HashMap;
    +import java.util.HashSet;
    +import java.util.LinkedList;
    +import java.util.List;
    +import java.util.Map;
    +import java.util.NavigableSet;
    +import java.util.Set;
    +import java.util.UUID;
    +import java.util.concurrent.ConcurrentSkipListSet;
    +import java.util.function.Function;
    +
    +import org.apache.rya.api.domain.RyaStatement;
    +import org.apache.rya.api.domain.RyaType;
    +import org.apache.rya.api.domain.RyaURI;
    +import org.apache.rya.api.domain.StatementMetadata;
    +import org.apache.rya.api.resolver.RdfToRyaConversions;
    +import org.apache.rya.mongodb.MongoDbRdfConstants;
    +import org.apache.rya.mongodb.dao.MongoDBStorageStrategy;
    +import org.apache.rya.mongodb.dao.SimpleMongoDBStorageStrategy;
    +import org.apache.rya.mongodb.document.operators.query.ConditionalOperators;
    +import org.apache.rya.mongodb.document.visibility.DocumentVisibilityAdapter;
    +import org.bson.Document;
    +import org.bson.conversions.Bson;
    +import org.openrdf.model.Literal;
    +import org.openrdf.model.Resource;
    +import org.openrdf.model.URI;
    +import org.openrdf.model.Value;
    +import org.openrdf.model.vocabulary.XMLSchema;
    +import org.openrdf.query.BindingSet;
    +import org.openrdf.query.QueryEvaluationException;
    +import org.openrdf.query.algebra.Compare;
    +import org.openrdf.query.algebra.ExtensionElem;
    +import org.openrdf.query.algebra.ProjectionElem;
    +import org.openrdf.query.algebra.ProjectionElemList;
    +import org.openrdf.query.algebra.StatementPattern;
    +import org.openrdf.query.algebra.ValueConstant;
    +import org.openrdf.query.algebra.ValueExpr;
    +import org.openrdf.query.algebra.Var;
    +import org.openrdf.query.algebra.evaluation.impl.ExternalSet;
    +
    +import com.google.common.base.Preconditions;
    +import com.google.common.collect.BiMap;
    +import com.google.common.collect.HashBiMap;
    +import com.mongodb.BasicDBObject;
    +import com.mongodb.DBObject;
    +import com.mongodb.client.MongoCollection;
    +import com.mongodb.client.model.Aggregates;
    +import com.mongodb.client.model.BsonField;
    +import com.mongodb.client.model.Filters;
    +import com.mongodb.client.model.Projections;
    +
    +import info.aduna.iteration.CloseableIteration;
    +
    +/**
    + * Represents a portion of a query tree as MongoDB aggregation pipeline. Should
    + * be built bottom-up: start with a statement pattern implemented as a $match
    + * step, then add steps to the pipeline to handle higher levels of the query
    + * tree. Methods are provided to add certain supported query operations to the
    + * end of the internal pipeline. In some cases, specific arguments may be
    + * unsupported, in which case the pipeline is unchanged and the method returns
    + * false.
    + */
    +public class AggregationPipelineQueryNode extends ExternalSet {
    +    /**
    +     * An aggregation result corresponding to a solution should map this key
    +     * to an object which itself maps variable names to variable values.
    +     */
    +    static final String VALUES = "<VALUES>";
    +
    +    /**
    +     * An aggregation result corresponding to a solution should map this key
    +     * to an object which itself maps variable names to the corresponding hashes
    +     * of their values.
    +     */
    +    static final String HASHES = "<HASHES>";
    +
    +    /**
    +     * An aggregation result corresponding to a solution should map this key
    +     * to an object which itself maps variable names to their datatypes, if any.
    +     */
    +    static final String TYPES = "<TYPES>";
    +
    +    private static final String LEVEL = "derivation_level";
    +    private static final String[] FIELDS = { VALUES, HASHES, TYPES, LEVEL, TIMESTAMP };
    +
    +    private static final String JOINED_TRIPLE = "<JOINED_TRIPLE>";
    +    private static final String FIELDS_MATCH = "<JOIN_FIELDS_MATCH>";
    +
    +    private static final MongoDBStorageStrategy<RyaStatement> strategy = new SimpleMongoDBStorageStrategy();
    +
    +    private static final Bson DEFAULT_TYPE = new Document("$literal", XMLSchema.ANYURI.stringValue());
    +    private static final Bson DEFAULT_CONTEXT = new Document("$literal", "");
    +    private static final Bson DEFAULT_DV = DocumentVisibilityAdapter.toDBObject(MongoDbRdfConstants.EMPTY_DV);
    +    private static final Bson DEFAULT_METADATA = new Document("$literal",
    +            StatementMetadata.EMPTY_METADATA.toString());
    +
    +    private static boolean isValidFieldName(String name) {
    +        return !(name == null || name.contains(".") || name.contains("$")
    +                || name.equals("_id"));
    +    }
    +
    +    /**
    +     * For a given statement pattern, represents a mapping from query variables
    +     * to their corresponding parts of matching triples. If necessary, also
    +     * substitute variable names including invalid characters with temporary
    +     * replacements, while producing a map back to the original names.
    +     */
    +    private static class StatementVarMapping {
    +        private final Map<String, String> varToTripleValue = new HashMap<>();
    +        private final Map<String, String> varToTripleHash = new HashMap<>();
    +        private final Map<String, String> varToTripleType = new HashMap<>();
    +        private final BiMap<String, String> varToOriginalName;
    +
    +        String valueField(String varName) {
    +            return varToTripleValue.get(varName);
    +        }
    +        String hashField(String varName) {
    +            return varToTripleHash.get(varName);
    +        }
    +        String typeField(String varName) {
    +            return varToTripleType.get(varName);
    +        }
    +
    +        Set<String> varNames() {
    +            return varToTripleValue.keySet();
    +        }
    +
    +        private String replace(String original) {
    +            if (varToOriginalName.containsValue(original)) {
    +                return varToOriginalName.inverse().get(original);
    +            }
    +            else {
    +                String replacement = "field-" + UUID.randomUUID();
    +                varToOriginalName.put(replacement, original);
    +                return replacement;
    +            }
    +        }
    +
    +        private String sanitize(String name) {
    +            if (varToOriginalName.containsValue(name)) {
    +                return varToOriginalName.inverse().get(name);
    +            }
    +            else if (name != null && !isValidFieldName(name)) {
    +                return replace(name);
    +            }
    +            return name;
    +        }
    +
    +        StatementVarMapping(StatementPattern sp, BiMap<String, String> varToOriginalName) {
    +            this.varToOriginalName = varToOriginalName;
    +            if (sp.getSubjectVar() != null && !sp.getSubjectVar().hasValue()) {
    +                String name = sanitize(sp.getSubjectVar().getName());
    +                varToTripleValue.put(name, SUBJECT);
    +                varToTripleHash.put(name, SUBJECT_HASH);
    +            }
    +            if (sp.getPredicateVar() != null && !sp.getPredicateVar().hasValue()) {
    +                String name = sanitize(sp.getPredicateVar().getName());
    +                varToTripleValue.put(name, PREDICATE);
    +                varToTripleHash.put(name, PREDICATE_HASH);
    +            }
    +            if (sp.getObjectVar() != null && !sp.getObjectVar().hasValue()) {
    +                String name = sanitize(sp.getObjectVar().getName());
    +                varToTripleValue.put(name, OBJECT);
    +                varToTripleHash.put(name, OBJECT_HASH);
    +                varToTripleType.put(name, OBJECT_TYPE);
    +            }
    +            if (sp.getContextVar() != null && !sp.getContextVar().hasValue()) {
    +                String name = sanitize(sp.getContextVar().getName());
    +                varToTripleValue.put(name, CONTEXT);
    +            }
    +        }
    +
    +        Bson getProjectExpression() {
    +            return getProjectExpression(new LinkedList<>(), str -> "$" + str);
    +        }
    +
    +        Bson getProjectExpression(Iterable<String> alsoInclude,
    +                Function<String, String> getFieldExpr) {
    +            Document values = new Document();
    +            Document hashes = new Document();
    +            Document types = new Document();
    +            for (String varName : varNames()) {
    +                values.append(varName, getFieldExpr.apply(valueField(varName)));
    +                if (varToTripleHash.containsKey(varName)) {
    +                    hashes.append(varName, getFieldExpr.apply(hashField(varName)));
    +                }
    +                if (varToTripleType.containsKey(varName)) {
    +                    types.append(varName, getFieldExpr.apply(typeField(varName)));
    +                }
    +            }
    +            for (String varName : alsoInclude) {
    +                values.append(varName, 1);
    +                hashes.append(varName, 1);
    +                types.append(varName, 1);
    +            }
    +            List<Bson> fields = new LinkedList<>();
    +            fields.add(Projections.excludeId());
    +            fields.add(Projections.computed(VALUES, values));
    +            fields.add(Projections.computed(HASHES, hashes));
    +            if (!types.isEmpty()) {
    +                fields.add(Projections.computed(TYPES, types));
    +            }
    +            fields.add(Projections.computed(LEVEL,
    +                    maxValueExpr("$" + LEVEL, getFieldExpr.apply(LEVEL), 
0)));
    +            fields.add(Projections.computed(TIMESTAMP,
    +                    maxValueExpr("$" + TIMESTAMP, 
getFieldExpr.apply(TIMESTAMP), 0)));
    +            return Projections.fields(fields);
    +        }
    +    }
    +
    +    /**
    +     * Generate a projection expression that evaluates to the maximum of two
    +     * fields and a default value.
    +     */
    +    private static Document maxValueExpr(String field1, String field2, Object defaultValue) {
    +        if (field1.equals(field2)) {
    +            return ConditionalOperators.ifNull(field1, defaultValue);
    +        }
    +        else {
    +            Document vars = new Document("x", 
ConditionalOperators.ifNull(field1, defaultValue))
    +                    .append("y", ConditionalOperators.ifNull(field2, 
defaultValue));
    +            Document gt = new Document("$gt", Arrays.asList("$$x", "$$y"));
    +            Document maxExpr = new Document("$cond",
    +                    new Document("if", gt).append("then", 
"$$x").append("else", "$$y"));
    +            return new Document("$let", new Document("vars", 
vars).append("in", maxExpr));
    +        }
    +    }
    +
    +    /**
    +     * Given a StatementPattern, generate an object representing the arguments
    +     * to a "$match" command that will find matching triples.
    +     * @param sp The StatementPattern to search for
    +     * @param path If given, specify the field that should be matched against
    +     *  the statement pattern, using an ordered list of field names for a nested
    +     *  field. E.g. to match records { "x": { "y": <statement pattern> } }, pass
    +     *  "x" followed by "y".
    +     * @return The argument of a "$match" query
    +     */
    +    private static BasicDBObject getMatchExpression(StatementPattern sp, String ... path) {
    +        final Var subjVar = sp.getSubjectVar();
    +        final Var predVar = sp.getPredicateVar();
    +        final Var objVar = sp.getObjectVar();
    +        final Var contextVar = sp.getContextVar();
    +        RyaURI s = null;
    +        RyaURI p = null;
    +        RyaType o = null;
    +        RyaURI c = null;
    +        if (subjVar != null && subjVar.getValue() instanceof Resource) {
    +            s = RdfToRyaConversions.convertResource((Resource) subjVar.getValue());
    +        }
    +        if (predVar != null && predVar.getValue() instanceof URI) {
    +            p = RdfToRyaConversions.convertURI((URI) predVar.getValue());
    +        }
    +        if (objVar != null && objVar.getValue() != null) {
    +            o = RdfToRyaConversions.convertValue(objVar.getValue());
    +        }
    +        if (contextVar != null && contextVar.getValue() instanceof URI) {
    +            c = RdfToRyaConversions.convertURI((URI) contextVar.getValue());
    +        }
    +        RyaStatement rs = new RyaStatement(s, p, o, c);
    +        DBObject obj = strategy.getQuery(rs);
    +        // Add path prefix, if given
    +        if (path.length > 0) {
    +            StringBuilder sb = new StringBuilder();
    +            for (String str : path) {
    +                sb.append(str).append(".");
    +            }
    +            String prefix = sb.toString();
    +            Set<String> originalKeys = new HashSet<>(obj.keySet());
    +            originalKeys.forEach(key -> {
    +                Object value = obj.removeField(key);
    +                obj.put(prefix + key, value);
    +            });
    +        }
    +        return (BasicDBObject) obj;
    +    }
    +
    +    private static String valueFieldExpr(String varName) {
    +        return "$" + VALUES + "." + varName;
    +    }
    +    private static String hashFieldExpr(String varName) {
    +        return "$" + HASHES + "." + varName;
    +    }
    +    private static String typeFieldExpr(String varName) {
    +        return "$" + TYPES + "." + varName;
    +    }
    +    private static String joinFieldExpr(String triplePart) {
    +        return "$" + JOINED_TRIPLE + "." + triplePart;
    +    }
    +
    +    /**
    +     * Get an object representing the value field of some value expression, or
    +     * return null if the expression isn't supported.
    +     */
    +    private Object valueFieldExpr(ValueExpr expr) {
    +        if (expr instanceof Var) {
    +            return valueFieldExpr(((Var) expr).getName());
    +        }
    +        else if (expr instanceof ValueConstant) {
    +            return new Document("$literal", ((ValueConstant) 
expr).getValue().stringValue());
    +        }
    +        else {
    +            return null;
    +        }
    +    }
    +
    +    private final List<Bson> pipeline;
    +    private final MongoCollection<Document> collection;
    +    private final Set<String> assuredBindingNames;
    +    private final Set<String> bindingNames;
    +    private final BiMap<String, String> varToOriginalName;
    +
    +    private String replace(String original) {
    +        if (varToOriginalName.containsValue(original)) {
    +            return varToOriginalName.inverse().get(original);
    +        }
    +        else {
    +            String replacement = "field-" + UUID.randomUUID();
    +            varToOriginalName.put(replacement, original);
    +            return replacement;
    +        }
    +    }
    +
    +    /**
    +     * Create a pipeline based on a StatementPattern.
    +     * @param baseSP The leaf node in the query tree.
    +     */
    +    public AggregationPipelineQueryNode(MongoCollection<Document> collection, StatementPattern baseSP) {
    +        Preconditions.checkNotNull(collection);
    +        Preconditions.checkNotNull(baseSP);
    +        this.collection = collection;
    +        this.varToOriginalName = HashBiMap.create();
    +        StatementVarMapping mapping = new StatementVarMapping(baseSP, varToOriginalName);
    +        this.assuredBindingNames = new HashSet<>(mapping.varNames());
    +        this.bindingNames = new HashSet<>(mapping.varNames());
    +        this.pipeline = new LinkedList<>();
    +        this.pipeline.add(Aggregates.match(getMatchExpression(baseSP)));
    +        this.pipeline.add(Aggregates.project(mapping.getProjectExpression()));
    +    }
    +
    +    AggregationPipelineQueryNode(MongoCollection<Document> collection,
    +            List<Bson> pipeline, Set<String> assuredBindingNames,
    +            Set<String> bindingNames, BiMap<String, String> varToOriginalName) {
    +        Preconditions.checkNotNull(collection);
    +        Preconditions.checkNotNull(pipeline);
    +        Preconditions.checkNotNull(assuredBindingNames);
    +        Preconditions.checkNotNull(bindingNames);
    +        Preconditions.checkNotNull(varToOriginalName);
    +        this.collection = collection;
    +        this.pipeline = pipeline;
    +        this.assuredBindingNames = assuredBindingNames;
    +        this.bindingNames = bindingNames;
    +        this.varToOriginalName = varToOriginalName;
    +    }
    +
    +    @Override
    +    public boolean equals(Object o) {
    +        if (this == o) {
    +            return true;
    +        }
    +        if (o instanceof AggregationPipelineQueryNode) {
    +            AggregationPipelineQueryNode other = (AggregationPipelineQueryNode) o;
    +            if (this.collection.equals(other.collection)
    +                    && this.assuredBindingNames.equals(other.assuredBindingNames)
    +                    && this.bindingNames.equals(other.bindingNames)
    +                    && this.varToOriginalName.equals(other.varToOriginalName)
    +                    && this.pipeline.size() == other.pipeline.size()) {
    +                // Check pipeline steps for equality -- underlying types don't
    +                // have well-behaved equals methods, so check for equivalent
    +                // string representations.
    +                for (int i = 0; i < this.pipeline.size(); i++) {
    +                    Bson doc1 = this.pipeline.get(i);
    +                    Bson doc2 = other.pipeline.get(i);
    +                    if (!doc1.toString().equals(doc2.toString())) {
    +                        return false;
    +                    }
    +                }
    +                return true;
    +            }
    +        }
    +        return false;
    +    }
    +
    +    @Override
    +    public int hashCode() {
    +        int result = collection.hashCode();
    +        for (Bson step : pipeline) {
    +            result = result * 37 + step.toString().hashCode();
    +        }
    +        result = result * 37 + assuredBindingNames.hashCode();
    +        result = result * 37 + bindingNames.hashCode();
    +        result = result * 37 + varToOriginalName.hashCode();
    +        return result;
    +    }
    +
    +    @Override
    +    public CloseableIteration<BindingSet, QueryEvaluationException> evaluate(BindingSet bindings)
    +            throws QueryEvaluationException {
    +        return new PipelineResultIteration(collection.aggregate(pipeline), varToOriginalName, bindings);
    +    }
    +
    +    @Override
    +    public Set<String> getAssuredBindingNames() {
    +        Set<String> names = new HashSet<>();
    +        for (String name : assuredBindingNames) {
    +            names.add(varToOriginalName.getOrDefault(name, name));
    +        }
    +        return names;
    +    }
    +
    +    @Override
    +    public Set<String> getBindingNames() {
    +        Set<String> names = new HashSet<>();
    +        for (String name : bindingNames) {
    +            names.add(varToOriginalName.getOrDefault(name, name));
    +        }
    +        return names;
    +    }
    +
    +    @Override
    +    public AggregationPipelineQueryNode clone() {
    +        return new AggregationPipelineQueryNode(collection,
    +                new LinkedList<>(pipeline),
    +                new HashSet<>(assuredBindingNames),
    +                new HashSet<>(bindingNames),
    +                HashBiMap.create(varToOriginalName));
    +    }
    +
    +    @Override
    +    public String getSignature() {
    +        super.getSignature();
    +        Set<String> assured = getAssuredBindingNames();
    +        Set<String> any = getBindingNames();
    +        StringBuilder sb = new StringBuilder("AggregationPipelineQueryNode 
(binds: ");
    +        sb.append(String.join(", ", assured));
    +        if (any.size() > assured.size()) {
    +            Set<String> optionalBindingNames = any;
    +            optionalBindingNames.removeAll(assured);
    +            sb.append(" [")
    +                .append(String.join(", ", optionalBindingNames))
    +                .append("]");
    +        }
    +        sb.append(")\n");
    +        for (Bson doc : pipeline) {
    +            sb.append(doc.toString()).append("\n");
    +        }
    +        return sb.toString();
    +    }
    +
    +    /**
    +     * Get the internal list of aggregation pipeline steps. Note that documents
    +     * resulting from this pipeline will be structured using an internal
    +     * intermediate representation. For documents representing triples, see
    +     * {@link #getTriplePipeline}, and for query solutions, see
    +     * {@link #evaluate}.
    +     * @return The current internal pipeline.
    +     */
    +    List<Bson> getPipeline() {
    +        return pipeline;
    +    }
    +
    +    /**
    +     * Add a join with an individual {@link StatementPattern} to the pipeline.
    +     * @param sp The statement pattern to join with
    +     * @return true if the join was successfully added to the pipeline.
    +     */
    +    public boolean joinWith(StatementPattern sp) {
    --- End diff --
    
    A null check on sp is needed here.
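    
    For example, a minimal sketch mirroring the Preconditions.checkNotNull calls already used in this class's constructors (only the guard is shown; the rest of the method body is elided):
    
        public boolean joinWith(StatementPattern sp) {
            // Fail fast if the caller passes a null statement pattern, consistent
            // with the argument checks in the constructors above.
            Preconditions.checkNotNull(sp);
            // ... existing join logic ...
        }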

