This is an automated email from the ASF dual-hosted git repository. volodymyr pushed a commit to branch mongo in repository https://gitbox.apache.org/repos/asf/drill.git
commit 40d6f1e56ad64c665c69db73efd1c51fba9cd42d Author: Volodymyr Vysotskyi <[email protected]> AuthorDate: Thu Jul 8 21:31:37 2021 +0300 Initial changes --- .../store/druid/DruidPushDownFilterForScan.java | 2 +- .../drill/exec/store/mongo/MongoGroupScan.java | 3 +- .../store/mongo/MongoPushDownAggregateForScan.java | 301 +++++++++++++++++++++ .../store/mongo/MongoPushDownFilterForScan.java | 2 +- .../drill/exec/store/mongo/MongoRecordReader.java | 29 +- .../drill/exec/store/mongo/MongoScanSpec.java | 31 ++- .../drill/exec/store/mongo/MongoStoragePlugin.java | 24 +- .../drill/exec/store/mongo/MongoSubScan.java | 10 + .../drill/exec/store/mongo/TestMongoQueries.java | 56 ++++ .../apache/drill/exec/planner/PlannerPhase.java | 1 + .../exec/planner/common/DrillScanRelBase.java | 4 +- .../logical/DrillPushProjectIntoScanRule.java | 26 +- .../drill/exec/planner/logical/DrillScanRel.java | 4 +- .../drill/exec/planner/physical/ScanPrel.java | 4 +- .../exec/planner/physical/StreamAggPrule.java | 1 + 15 files changed, 468 insertions(+), 30 deletions(-) diff --git a/contrib/storage-druid/src/main/java/org/apache/drill/exec/store/druid/DruidPushDownFilterForScan.java b/contrib/storage-druid/src/main/java/org/apache/drill/exec/store/druid/DruidPushDownFilterForScan.java index 65d95aa..2c5fcee 100644 --- a/contrib/storage-druid/src/main/java/org/apache/drill/exec/store/druid/DruidPushDownFilterForScan.java +++ b/contrib/storage-druid/src/main/java/org/apache/drill/exec/store/druid/DruidPushDownFilterForScan.java @@ -73,7 +73,7 @@ public class DruidPushDownFilterForScan extends StoragePluginOptimizerRule { groupScan.getMaxRecordsToRead()); newGroupsScan.setFilterPushedDown(true); - ScanPrel newScanPrel = scan.copy(filter.getTraitSet(), newGroupsScan); + ScanPrel newScanPrel = scan.copy(filter.getTraitSet(), newGroupsScan, filter.getRowType()); if (druidFilterBuilder.isAllExpressionsConverted()) { /* * Since we could convert the entire filter condition expression into a diff --git a/contrib/storage-mongo/src/main/java/org/apache/drill/exec/store/mongo/MongoGroupScan.java b/contrib/storage-mongo/src/main/java/org/apache/drill/exec/store/mongo/MongoGroupScan.java index 8b57012..6662e8c 100644 --- a/contrib/storage-mongo/src/main/java/org/apache/drill/exec/store/mongo/MongoGroupScan.java +++ b/contrib/storage-mongo/src/main/java/org/apache/drill/exec/store/mongo/MongoGroupScan.java @@ -466,7 +466,8 @@ public class MongoGroupScan extends AbstractGroupScan implements .setMinFilters(chunkInfo.getMinFilters()) .setMaxFilters(chunkInfo.getMaxFilters()) .setMaxRecords(maxRecords) - .setFilter(scanSpec.getFilters()); + .setFilter(scanSpec.getFilters()) + .setAggregates(scanSpec.getAggregates()); } @Override diff --git a/contrib/storage-mongo/src/main/java/org/apache/drill/exec/store/mongo/MongoPushDownAggregateForScan.java b/contrib/storage-mongo/src/main/java/org/apache/drill/exec/store/mongo/MongoPushDownAggregateForScan.java new file mode 100644 index 0000000..f7e1a00 --- /dev/null +++ b/contrib/storage-mongo/src/main/java/org/apache/drill/exec/store/mongo/MongoPushDownAggregateForScan.java @@ -0,0 +1,301 @@ +package org.apache.drill.exec.store.mongo; + +import org.apache.calcite.avatica.util.DateTimeUtils; +import org.apache.calcite.linq4j.function.Function1; +import org.apache.calcite.linq4j.tree.Primitive; +import org.apache.calcite.plan.RelOptRuleCall; +import org.apache.calcite.plan.RelOptRuleOperand; +import org.apache.calcite.rel.core.Aggregate; +import org.apache.calcite.rel.core.AggregateCall; +import org.apache.calcite.rel.core.Project; +import org.apache.calcite.rel.type.RelDataType; +import org.apache.calcite.sql.SqlAggFunction; +import org.apache.calcite.sql.fun.SqlStdOperatorTable; +import org.apache.calcite.sql.fun.SqlSumAggFunction; +import org.apache.calcite.sql.fun.SqlSumEmptyIsZeroAggFunction; +import org.apache.calcite.sql.validate.SqlValidatorUtil; +import org.apache.calcite.util.Pair; +import org.apache.calcite.util.Util; +import org.apache.drill.common.exceptions.DrillRuntimeException; +import org.apache.drill.common.expression.SchemaPath; +import org.apache.drill.exec.planner.common.DrillScanRelBase; +import org.apache.drill.exec.planner.logical.RelOptHelper; +import org.apache.drill.exec.store.StoragePluginOptimizerRule; +import org.bson.BsonDocument; +import org.bson.BsonString; +import org.bson.Document; +import org.bson.conversions.Bson; + +import java.io.IOException; +import java.util.AbstractList; +import java.util.ArrayList; +import java.util.Date; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +public class MongoPushDownAggregateForScan extends StoragePluginOptimizerRule { + public static final StoragePluginOptimizerRule INSTANCE = new MongoPushDownAggregateForScan(RelOptHelper.some(Aggregate.class, RelOptHelper.any(DrillScanRelBase.class)), "MongoPushDownAggregateForScan"); + public static final StoragePluginOptimizerRule PROJ_INSTANCE = new MongoPushDownAggregateForScan(RelOptHelper.some(Aggregate.class, RelOptHelper.some(Project.class, RelOptHelper.any(DrillScanRelBase.class))), "MongoPushDownAggregateForScan_project"); + + public MongoPushDownAggregateForScan(RelOptRuleOperand operand, String desc) { + super(operand, desc); + } + + static List<String> mongoFieldNames(final RelDataType rowType) { + return SqlValidatorUtil.uniquify( + new AbstractList<String>() { + @Override public String get(int index) { + final String name = rowType.getFieldList().get(index).getName(); + return name.startsWith("$") ? "_" + name.substring(2) : name; + } + + @Override public int size() { + return rowType.getFieldCount(); + } + }, + SqlValidatorUtil.EXPR_SUGGESTER, true); + } + + static String maybeQuote(String s) { + if (!needsQuote(s)) { + return s; + } + return quote(s); + } + + static String quote(String s) { + return "'" + s + "'"; // TODO: handle embedded quotes + } + + private static boolean needsQuote(String s) { + for (int i = 0, n = s.length(); i < n; i++) { + char c = s.charAt(i); + if (!Character.isJavaIdentifierPart(c) + || c == '$') { + return true; + } + } + return false; + } + + @Override + public void onMatch(RelOptRuleCall call) { + Aggregate aggregate = call.rel(0); + DrillScanRelBase scan = call.rel(1); + + MongoGroupScan groupScan = (MongoGroupScan) scan.getGroupScan(); + + +// implementor.visitChild(0, getInput()); + List<String> list = new ArrayList<>(); +// List<BsonDocument> docList = new ArrayList<>(); + final List<String> inNames = + mongoFieldNames(scan.getRowType()); + final List<String> outNames = mongoFieldNames(aggregate.getRowType()); + int i = 0; + if (aggregate.getGroupSet().cardinality() == 1) { + final String inName = inNames.get(aggregate.getGroupSet().nth(0)); + list.add("_id: " + maybeQuote("$" + inName)); +// docList.add(new BsonDocument("_id", new BsonString(maybeQuote("$" + inName)))); + ++i; + } else { + List<String> keys = new ArrayList<>(); + for (int group : aggregate.getGroupSet()) { + final String inName = inNames.get(group); + keys.add(inName + ": " + quote("$" + inName)); + ++i; + } + list.add("_id: " + Util.toString(keys, "{", ", ", "}")); +// docList.add(new BsonDocument("_id", new BsonString(Util.toString(keys, "{", ", ", "}")))); + } + for (AggregateCall aggCall : aggregate.getAggCallList()) { + list.add( + maybeQuote(outNames.get(i++)) + ": " + + toMongo(aggCall.getAggregation(), inNames, aggCall.getArgList())); + } + List<Pair<String, String>> aggsList = new ArrayList<>(); + aggsList.add(Pair.of(null, "{$group: " + Util.toString(list, "{", ", ", "}") + "}")); + final List<String> fixups; + if (aggregate.getGroupSet().cardinality() == 1) { + fixups = new AbstractList<String>() { + @Override public String get(int index) { + final String outName = outNames.get(index); + return maybeQuote(outName) + ": " + + maybeQuote("$" + (index == 0 ? "_id" : outName)); + } + + @Override public int size() { + return outNames.size(); + } + }; + } else { + fixups = new ArrayList<>(); + fixups.add("_id: 0"); + i = 0; + for (int group : aggregate.getGroupSet()) { + fixups.add( + maybeQuote(outNames.get(group)) + + ": " + + maybeQuote("$_id." + outNames.get(group))); + ++i; + } + for (AggregateCall ignored : aggregate.getAggCallList()) { + final String outName = outNames.get(i++); + fixups.add( + maybeQuote(outName) + ": " + maybeQuote( + "$" + outName)); + } + } + if (!aggregate.getGroupSet().isEmpty()) { + aggsList.add(Pair.of(null, "{$project: " + Util.toString(fixups, "{", ", ", "}") + "}")); + } + + MongoScanSpec mongoScanSpec = aggregate(groupScan.getScanSpec(), Pair.right(aggsList)); + try { + List<SchemaPath> columns = outNames.stream() + .map(SchemaPath::getSimplePath) + .collect(Collectors.toList()); + MongoGroupScan mongoScanSpec123 = new MongoGroupScan(groupScan.getUserName(), groupScan.getStoragePlugin(), + mongoScanSpec, columns, groupScan.getMaxRecords()); + call.transformTo(scan.copy(aggregate.getTraitSet(), mongoScanSpec123, aggregate.getRowType())); + } catch (IOException e) { + throw new DrillRuntimeException(e.getMessage(), e); + } + } + + private static String toMongo(SqlAggFunction aggregation, List<String> inNames, + List<Integer> args) { + if (aggregation.getName().equals(SqlStdOperatorTable.COUNT.getName())) { + if (args.size() == 0) { +// Aggregates.count() + return "{$sum: 1}"; + } else { + assert args.size() == 1; +// Arrays.asList( +// Aggregates.match(Filters.eq("languages.name", "English")), +// Aggregates.count()) + final String inName = inNames.get(args.get(0)); + return "{$sum: {$cond: [ {$eq: [" + + quote(inName) + + ", null]}, 0, 1]}}"; + } + } else if (aggregation instanceof SqlSumAggFunction + || aggregation instanceof SqlSumEmptyIsZeroAggFunction) { + assert args.size() == 1; + final String inName = inNames.get(args.get(0)); + return "{$sum: " + maybeQuote("$" + inName) + "}"; + } else if (aggregation.getName().equals(SqlStdOperatorTable.MIN.getName())) { + assert args.size() == 1; + final String inName = inNames.get(args.get(0)); + return "{$min: " + maybeQuote("$" + inName) + "}"; + } else if (aggregation.getName().equals(SqlStdOperatorTable.MAX.getName())) { + assert args.size() == 1; + final String inName = inNames.get(args.get(0)); + return "{$max: " + maybeQuote("$" + inName) + "}"; + } else if (aggregation.getName().equals(SqlStdOperatorTable.AVG.getName())) { + assert args.size() == 1; + final String inName = inNames.get(args.get(0)); + return "{$avg: " + maybeQuote("$" + inName) + "}"; + } else { + throw new AssertionError("unknown aggregate " + aggregation); + } + } + + private MongoScanSpec aggregate(MongoScanSpec scanSpec, + final List<String> operations) { + final List<Bson> list = new ArrayList<>(); + for (String operation : operations) { + list.add(BsonDocument.parse(operation)); + } + return new MongoScanSpec(scanSpec.getDbName(), scanSpec.getCollectionName(), + scanSpec.getFilters(), list); +// final Function1<Document, Object> getter = +// getter(fields); +// return new AbstractEnumerable<Object>() { +// @Override public Enumerator<Object> enumerator() { +// final Iterator<Document> resultIterator; +// try { +// resultIterator = mongoDb.getCollection(scanSpec.getCollectionName()) +// .aggregate(list).iterator(); +// } catch (Exception e) { +// throw new RuntimeException("While running MongoDB query " +// + Util.toString(operations, "[", ",\n", "]"), e); +// } +// return new MongoEnumerator(resultIterator, getter); +// } +// }; + } + + static Function1<Document, Map> mapGetter() { + return a0 -> (Map) a0; + } + + /** Returns a function that projects a single field. */ + static Function1<Document, Object> singletonGetter(final String fieldName, + final Class fieldClass) { + return a0 -> convert(a0.get(fieldName), fieldClass); + } + + /** Returns a function that projects fields. + * + * @param fields List of fields to project; or null to return map + */ + static Function1<Document, Object[]> listGetter( + final List<Map.Entry<String, Class>> fields) { + return a0 -> { + Object[] objects = new Object[fields.size()]; + for (int i = 0; i < fields.size(); i++) { + final Map.Entry<String, Class> field = fields.get(i); + final String name = field.getKey(); + objects[i] = convert(a0.get(name), field.getValue()); + } + return objects; + }; + } + + static Function1<Document, Object> getter( + List<Map.Entry<String, Class>> fields) { + //noinspection unchecked + return fields == null + ? (Function1) mapGetter() + : fields.size() == 1 + ? singletonGetter(fields.get(0).getKey(), fields.get(0).getValue()) + : (Function1) listGetter(fields); + } + + @SuppressWarnings("JavaUtilDate") + private static Object convert(Object o, Class clazz) { + if (o == null) { + return null; + } + Primitive primitive = Primitive.of(clazz); + if (primitive != null) { + clazz = primitive.boxClass; + } else { + primitive = Primitive.ofBox(clazz); + } + if (clazz.isInstance(o)) { + return o; + } + if (o instanceof Date && primitive != null) { + o = ((Date) o).getTime() / DateTimeUtils.MILLIS_PER_DAY; + } + if (o instanceof Number && primitive != null) { + return primitive.number((Number) o); + } + return o; + } + + //$addToSet + //$avg + //$first + //$last + //$max + //$min + //$mergeObjects + //$push + //$stdDevPop + //$stdDevSamp + //$sum +} diff --git a/contrib/storage-mongo/src/main/java/org/apache/drill/exec/store/mongo/MongoPushDownFilterForScan.java b/contrib/storage-mongo/src/main/java/org/apache/drill/exec/store/mongo/MongoPushDownFilterForScan.java index 5e57890..b1c06e7 100644 --- a/contrib/storage-mongo/src/main/java/org/apache/drill/exec/store/mongo/MongoPushDownFilterForScan.java +++ b/contrib/storage-mongo/src/main/java/org/apache/drill/exec/store/mongo/MongoPushDownFilterForScan.java @@ -76,7 +76,7 @@ public class MongoPushDownFilterForScan extends StoragePluginOptimizerRule { } newGroupsScan.setFilterPushedDown(true); - RelNode newScanPrel = scan.copy(filter.getTraitSet(), newGroupsScan); + RelNode newScanPrel = scan.copy(filter.getTraitSet(), newGroupsScan, filter.getRowType()); if (mongoFilterBuilder.isAllExpressionsConverted()) { /* diff --git a/contrib/storage-mongo/src/main/java/org/apache/drill/exec/store/mongo/MongoRecordReader.java b/contrib/storage-mongo/src/main/java/org/apache/drill/exec/store/mongo/MongoRecordReader.java index b06fe36..7c4f3f2 100644 --- a/contrib/storage-mongo/src/main/java/org/apache/drill/exec/store/mongo/MongoRecordReader.java +++ b/contrib/storage-mongo/src/main/java/org/apache/drill/exec/store/mongo/MongoRecordReader.java @@ -18,6 +18,7 @@ package org.apache.drill.exec.store.mongo; import java.io.IOException; +import java.util.ArrayList; import java.util.Collection; import java.util.List; import java.util.Map; @@ -25,6 +26,9 @@ import java.util.Map.Entry; import java.util.Set; import java.util.concurrent.TimeUnit; +import com.mongodb.client.FindIterable; +import com.mongodb.client.model.Aggregates; +import org.apache.commons.collections.CollectionUtils; import org.apache.drill.common.exceptions.DrillRuntimeException; import org.apache.drill.common.exceptions.ExecutionSetupException; import org.apache.drill.common.expression.SchemaPath; @@ -40,6 +44,7 @@ import org.apache.drill.exec.vector.complex.impl.VectorContainerWriter; import org.bson.BsonDocument; import org.bson.BsonDocumentReader; import org.bson.Document; +import org.bson.conversions.Bson; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -64,6 +69,7 @@ public class MongoRecordReader extends AbstractRecordReader { private VectorContainerWriter writer; private Document filters; + private List<Bson> aggregates; private final Document fields; private final FragmentContext fragmentContext; @@ -87,6 +93,7 @@ public class MongoRecordReader extends AbstractRecordReader { fragmentContext = context; this.plugin = plugin; filters = new Document(); + aggregates = subScanSpec.aggregates; Map<String, List<Document>> mergedFilters = MongoUtils.mergeFilters( subScanSpec.getMinFilters(), subScanSpec.getMaxFilters()); @@ -176,12 +183,24 @@ public class MongoRecordReader extends AbstractRecordReader { logger.debug("Filters Applied : " + filters); logger.debug("Fields Selected :" + fields); - // Add limit to Mongo query - if (maxRecords > 0) { - logger.debug("Limit applied: {}", maxRecords); - cursor = collection.find(filters).projection(fields).limit(maxRecords).batchSize(100).iterator(); + if (CollectionUtils.isNotEmpty(aggregates)) { + List<Bson> operations = new ArrayList<>(); + operations.add(Aggregates.match(filters)); + operations.addAll(aggregates); + operations.add(Aggregates.project(fields)); + if (maxRecords > 0) { + operations.add(Aggregates.limit(maxRecords)); + } + cursor = collection.aggregate(operations).batchSize(100).iterator(); } else { - cursor = collection.find(filters).projection(fields).batchSize(100).iterator(); + // Add limit to Mongo query + FindIterable<BsonDocument> projection = collection.find(filters).projection(fields); + if (maxRecords > 0) { + logger.debug("Limit applied: {}", maxRecords); + projection = projection.limit(maxRecords); + } + + cursor = projection.batchSize(100).iterator(); } } diff --git a/contrib/storage-mongo/src/main/java/org/apache/drill/exec/store/mongo/MongoScanSpec.java b/contrib/storage-mongo/src/main/java/org/apache/drill/exec/store/mongo/MongoScanSpec.java index 5c56fcc..7ec1210 100644 --- a/contrib/storage-mongo/src/main/java/org/apache/drill/exec/store/mongo/MongoScanSpec.java +++ b/contrib/storage-mongo/src/main/java/org/apache/drill/exec/store/mongo/MongoScanSpec.java @@ -21,13 +21,19 @@ import org.bson.Document; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; +import org.bson.conversions.Bson; + +import java.util.List; +import java.util.StringJoiner; public class MongoScanSpec { - private String dbName; - private String collectionName; + private final String dbName; + private final String collectionName; private Document filters; + private List<Bson> aggregates; + @JsonCreator public MongoScanSpec(@JsonProperty("dbName") String dbName, @JsonProperty("collectionName") String collectionName) { @@ -42,6 +48,14 @@ public class MongoScanSpec { this.filters = filters; } + public MongoScanSpec(String dbName, String collectionName, + Document filters, List<Bson> aggregates) { + this.dbName = dbName; + this.collectionName = collectionName; + this.filters = filters; + this.aggregates = aggregates; + } + public String getDbName() { return dbName; } @@ -54,10 +68,17 @@ public class MongoScanSpec { return filters; } + public List<Bson> getAggregates() { + return aggregates; + } + @Override public String toString() { - return "MongoScanSpec [dbName=" + dbName + ", collectionName=" - + collectionName + ", filters=" + filters + "]"; + return new StringJoiner(", ", MongoScanSpec.class.getSimpleName() + "[", "]") + .add("dbName='" + dbName + "'") + .add("collectionName='" + collectionName + "'") + .add("filters=" + filters) + .add("aggregates=" + aggregates) + .toString(); } - } diff --git a/contrib/storage-mongo/src/main/java/org/apache/drill/exec/store/mongo/MongoStoragePlugin.java b/contrib/storage-mongo/src/main/java/org/apache/drill/exec/store/mongo/MongoStoragePlugin.java index da55907..f6c3ac2 100644 --- a/contrib/storage-mongo/src/main/java/org/apache/drill/exec/store/mongo/MongoStoragePlugin.java +++ b/contrib/storage-mongo/src/main/java/org/apache/drill/exec/store/mongo/MongoStoragePlugin.java @@ -25,16 +25,16 @@ import com.mongodb.client.MongoClient; import com.mongodb.MongoCredential; import com.mongodb.ServerAddress; import com.mongodb.client.MongoClients; +import org.apache.calcite.plan.RelOptRule; import org.apache.calcite.schema.SchemaPlus; import org.apache.drill.common.JSONOptions; import org.apache.drill.common.exceptions.DrillRuntimeException; -import org.apache.drill.common.exceptions.ExecutionSetupException; import org.apache.drill.exec.ops.OptimizerRulesContext; import org.apache.drill.exec.physical.base.AbstractGroupScan; +import org.apache.drill.exec.planner.PlannerPhase; import org.apache.drill.exec.server.DrillbitContext; import org.apache.drill.exec.store.AbstractStoragePlugin; import org.apache.drill.exec.store.SchemaConfig; -import org.apache.drill.exec.store.StoragePluginOptimizerRule; import org.apache.drill.exec.store.mongo.schema.MongoSchemaFactory; import org.apache.drill.common.logical.security.CredentialsProvider; import org.apache.drill.exec.store.security.HadoopCredentialsProvider; @@ -52,6 +52,7 @@ import org.slf4j.LoggerFactory; import java.io.IOException; import java.net.URLEncoder; +import java.util.Collections; import java.util.List; import java.util.Set; import java.util.concurrent.ExecutionException; @@ -68,7 +69,7 @@ public class MongoStoragePlugin extends AbstractStoragePlugin { public MongoStoragePlugin( MongoStoragePluginConfig mongoConfig, DrillbitContext context, - String name) throws ExecutionSetupException { + String name) { super(context, name); this.mongoConfig = mongoConfig; String connection = addCredentialsFromCredentialsProvider(this.mongoConfig.getConnection(), name); @@ -120,7 +121,7 @@ public class MongoStoragePlugin extends AbstractStoragePlugin { } @Override - public void registerSchemas(SchemaConfig schemaConfig, SchemaPlus parent) throws IOException { + public void registerSchemas(SchemaConfig schemaConfig, SchemaPlus parent) { schemaFactory.registerSchemas(schemaConfig, parent); } @@ -137,8 +138,19 @@ public class MongoStoragePlugin extends AbstractStoragePlugin { } @Override - public Set<StoragePluginOptimizerRule> getPhysicalOptimizerRules(OptimizerRulesContext optimizerRulesContext) { - return ImmutableSet.of(MongoPushDownFilterForScan.INSTANCE); + public Set<? extends RelOptRule> getOptimizerRules(OptimizerRulesContext optimizerContext, PlannerPhase phase) { + switch (phase) { + case PHYSICAL: + case LOGICAL: + return ImmutableSet.of(MongoPushDownFilterForScan.INSTANCE, + MongoPushDownAggregateForScan.INSTANCE); + case LOGICAL_PRUNE_AND_JOIN: + case LOGICAL_PRUNE: + case PARTITION_PRUNING: + case JOIN_PLANNING: + default: + return Collections.emptySet(); + } } diff --git a/contrib/storage-mongo/src/main/java/org/apache/drill/exec/store/mongo/MongoSubScan.java b/contrib/storage-mongo/src/main/java/org/apache/drill/exec/store/mongo/MongoSubScan.java index af13eb5..a32336d 100644 --- a/contrib/storage-mongo/src/main/java/org/apache/drill/exec/store/mongo/MongoSubScan.java +++ b/contrib/storage-mongo/src/main/java/org/apache/drill/exec/store/mongo/MongoSubScan.java @@ -40,6 +40,7 @@ import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.annotation.JsonTypeName; import org.apache.drill.shaded.guava.com.google.common.base.Preconditions; +import org.bson.conversions.Bson; @JsonTypeName("mongo-shard-read") public class MongoSubScan extends AbstractBase implements SubScan { @@ -132,6 +133,7 @@ public class MongoSubScan extends AbstractBase implements SubScan { protected int maxRecords; protected Document filter; + protected List<Bson> aggregates; @JsonCreator public MongoSubScanSpec(@JsonProperty("dbName") String dbName, @@ -140,6 +142,7 @@ public class MongoSubScan extends AbstractBase implements SubScan { @JsonProperty("minFilters") Map<String, Object> minFilters, @JsonProperty("maxFilters") Map<String, Object> maxFilters, @JsonProperty("filters") Document filters, + @JsonProperty("aggregates") List<Bson> aggregates, @JsonProperty("maxRecords") int maxRecords) { this.dbName = dbName; this.collectionName = collectionName; @@ -147,6 +150,7 @@ public class MongoSubScan extends AbstractBase implements SubScan { this.minFilters = minFilters; this.maxFilters = maxFilters; this.filter = filters; + this.aggregates = aggregates; this.maxRecords = maxRecords; } @@ -215,6 +219,11 @@ public class MongoSubScan extends AbstractBase implements SubScan { return this; } + public MongoSubScanSpec setAggregates(List<Bson> aggregates) { + this.aggregates = aggregates; + return this; + } + @Override public String toString() { return new PlanStringBuilder(this) @@ -224,6 +233,7 @@ public class MongoSubScan extends AbstractBase implements SubScan { .field("minFilters", minFilters) .field("maxFilters", maxFilters) .field("filter", filter) + .field("aggregates", aggregates) .field("maxRecords", maxRecords) .toString(); diff --git a/contrib/storage-mongo/src/test/java/org/apache/drill/exec/store/mongo/TestMongoQueries.java b/contrib/storage-mongo/src/test/java/org/apache/drill/exec/store/mongo/TestMongoQueries.java index 4b20ebc..d5316f9 100644 --- a/contrib/storage-mongo/src/test/java/org/apache/drill/exec/store/mongo/TestMongoQueries.java +++ b/contrib/storage-mongo/src/test/java/org/apache/drill/exec/store/mongo/TestMongoQueries.java @@ -105,4 +105,60 @@ public class TestMongoQueries extends MongoTestBase { .expectsNumRecords(5) .go(); } + + @Test + public void testCountColumnPushDown() throws Exception { + String query = "select count(t.name) as c from mongo.%s.`%s` t"; + + queryBuilder().sql(query, DONUTS_DB, DONUTS_COLLECTION) + .planMatcher() + .exclude("Agg\\(") + .include("Scan\\(.*aggregates") + .match(); + + testBuilder() + .sqlQuery(query, DONUTS_DB, DONUTS_COLLECTION) + .unOrdered() + .baselineColumns("c") + .baselineValues(5) + .go(); + } + + @Test + public void testCountGroupByPushDown() throws Exception { + String query = "select count(t.id) as c, t.type from mongo.%s.`%s` t group by t.type"; + + queryBuilder().sql(query, DONUTS_DB, DONUTS_COLLECTION) + .planMatcher() + .exclude("Agg\\(") + .include("Scan\\(.*aggregates") + .match(); + + testBuilder() + .sqlQuery(query, DONUTS_DB, DONUTS_COLLECTION) + .unOrdered() + .baselineColumns("c", "type") + .baselineValues(5, "donut") + .go(); + } + + @Test + public void testCountColumnPushDownWithFilter() throws Exception { + String query = "select count(t.id) as c from mongo.%s.`%s` t where t.name = 'Cake'"; + + queryBuilder().sql(query, DONUTS_DB, DONUTS_COLLECTION) + .planMatcher() + .exclude("Agg\\(", "Filter") + .include("Scan\\(.*aggregates") + .match(); + + testBuilder() + .sqlQuery(query, DONUTS_DB, DONUTS_COLLECTION) + .unOrdered() + .baselineColumns("c") + .baselineValues(1) + .go(); + +// queryBuilder().sql("select * from mongo.%s.`%s` t", DONUTS_DB, DONUTS_COLLECTION).printCsv(); + } } diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/PlannerPhase.java b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/PlannerPhase.java index 97b34e1..cb0fd3c 100644 --- a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/PlannerPhase.java +++ b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/PlannerPhase.java @@ -347,6 +347,7 @@ public enum PlannerPhase { // RuleInstance.PROJECT_SET_OP_TRANSPOSE_RULE, RuleInstance.PROJECT_WINDOW_TRANSPOSE_RULE, DrillPushProjectIntoScanRule.INSTANCE, + DrillPushProjectIntoScanRule.LOGICAL_INSTANCE, DrillPushProjectIntoScanRule.DRILL_LOGICAL_INSTANCE, /* diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/common/DrillScanRelBase.java b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/common/DrillScanRelBase.java index fe67709..a307f93 100644 --- a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/common/DrillScanRelBase.java +++ b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/common/DrillScanRelBase.java @@ -19,6 +19,8 @@ package org.apache.drill.exec.planner.common; import java.io.IOException; import java.util.List; + +import org.apache.calcite.rel.type.RelDataType; import org.apache.drill.common.expression.SchemaPath; import org.apache.drill.exec.physical.base.GroupScan; import org.apache.drill.exec.planner.logical.DrillTable; @@ -87,5 +89,5 @@ public abstract class DrillScanRelBase extends TableScan implements DrillRelNode return planner.getCostFactory().makeCost(dRows, dCpu, dIo); } - public abstract DrillScanRelBase copy(RelTraitSet traitSet, GroupScan scan); + public abstract DrillScanRelBase copy(RelTraitSet traitSet, GroupScan scan, RelDataType rowType); } diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/logical/DrillPushProjectIntoScanRule.java b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/logical/DrillPushProjectIntoScanRule.java index 91875bb..d54bb42 100644 --- a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/logical/DrillPushProjectIntoScanRule.java +++ b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/logical/DrillPushProjectIntoScanRule.java @@ -56,11 +56,16 @@ public class DrillPushProjectIntoScanRule extends RelOptRule { } }; - public static final RelOptRule DRILL_LOGICAL_INSTANCE = + public static final RelOptRule LOGICAL_INSTANCE = new DrillPushProjectIntoScanRule(LogicalProject.class, DrillScanRel.class, "DrillPushProjectIntoScanRule:logical"); + public static final RelOptRule DRILL_LOGICAL_INSTANCE = + new DrillPushProjectIntoScanRule(DrillProjectRel.class, + DrillScanRel.class, + "DrillPushProjectIntoScanRule:drill_logical"); + public static final RelOptRule DRILL_PHYSICAL_INSTANCE = new DrillPushProjectIntoScanRule(ProjectPrel.class, ScanPrel.class, @@ -167,11 +172,20 @@ public class DrillPushProjectIntoScanRule extends RelOptRule { * @return new scan instance */ protected TableScan createScan(TableScan scan, ProjectPushInfo projectPushInfo) { - return new DrillScanRel(scan.getCluster(), - scan.getTraitSet().plus(DrillRel.DRILL_LOGICAL), - scan.getTable(), - projectPushInfo.createNewRowType(scan.getCluster().getTypeFactory()), - projectPushInfo.getFields()); + if (scan instanceof DrillScanRel) { + return new DrillScanRel(scan.getCluster(), + scan.getTraitSet().plus(DrillRel.DRILL_LOGICAL), + scan.getTable(), + ((DrillScanRel) scan).getGroupScan().clone(projectPushInfo.getFields()), + projectPushInfo.createNewRowType(scan.getCluster().getTypeFactory()), + projectPushInfo.getFields()); + } else { + return new DrillScanRel(scan.getCluster(), + scan.getTraitSet().plus(DrillRel.DRILL_LOGICAL), + scan.getTable(), + projectPushInfo.createNewRowType(scan.getCluster().getTypeFactory()), + projectPushInfo.getFields()); + } } /** diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/logical/DrillScanRel.java b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/logical/DrillScanRel.java index 26ef4ea..bcd9792 100644 --- a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/logical/DrillScanRel.java +++ b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/logical/DrillScanRel.java @@ -193,7 +193,7 @@ public class DrillScanRel extends DrillScanRelBase implements DrillRel { } @Override - public DrillScanRel copy(RelTraitSet traitSet, GroupScan scan) { - return new DrillScanRel(getCluster(), getTraitSet(), getTable(), scan, getRowType(), getColumns(), partitionFilterPushdown()); + public DrillScanRel copy(RelTraitSet traitSet, GroupScan scan, RelDataType rowType) { + return new DrillScanRel(getCluster(), getTraitSet(), getTable(), scan, rowType, getColumns(), partitionFilterPushdown()); } } diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/physical/ScanPrel.java b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/physical/ScanPrel.java index 50996b9..1e0bdf4 100644 --- a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/physical/ScanPrel.java +++ b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/physical/ScanPrel.java @@ -60,8 +60,8 @@ public class ScanPrel extends DrillScanRelBase implements LeafPrel, HasDistribut } @Override - public ScanPrel copy(RelTraitSet traitSet, GroupScan scan) { - return new ScanPrel(getCluster(), traitSet, scan, getRowType(), getTable()); + public ScanPrel copy(RelTraitSet traitSet, GroupScan scan, RelDataType rowType) { + return new ScanPrel(getCluster(), traitSet, scan, rowType, getTable()); } @Override diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/physical/StreamAggPrule.java b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/physical/StreamAggPrule.java index 0b68014..9fd789c 100644 --- a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/physical/StreamAggPrule.java +++ b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/physical/StreamAggPrule.java @@ -54,6 +54,7 @@ public class StreamAggPrule extends AggPruleBase { @Override public void onMatch(RelOptRuleCall call) { + final DrillAggregateRel aggregate = call.rel(0); RelNode input = aggregate.getInput(); final RelCollation collation = getCollation(aggregate);
