You are viewing a plain-text version of this content. The canonical version is on the original mailing-list archive page; its hyperlink ("here") was lost in this plain-text rendering.
Posted to commits@drill.apache.org by vo...@apache.org on 2021/08/05 17:23:14 UTC
[drill] 01/13: Initial changes
This is an automated email from the ASF dual-hosted git repository.
volodymyr pushed a commit to branch mongo
in repository https://gitbox.apache.org/repos/asf/drill.git
commit 40d6f1e56ad64c665c69db73efd1c51fba9cd42d
Author: Volodymyr Vysotskyi <vv...@gmail.com>
AuthorDate: Thu Jul 8 21:31:37 2021 +0300
Initial changes
---
.../store/druid/DruidPushDownFilterForScan.java | 2 +-
.../drill/exec/store/mongo/MongoGroupScan.java | 3 +-
.../store/mongo/MongoPushDownAggregateForScan.java | 301 +++++++++++++++++++++
.../store/mongo/MongoPushDownFilterForScan.java | 2 +-
.../drill/exec/store/mongo/MongoRecordReader.java | 29 +-
.../drill/exec/store/mongo/MongoScanSpec.java | 31 ++-
.../drill/exec/store/mongo/MongoStoragePlugin.java | 24 +-
.../drill/exec/store/mongo/MongoSubScan.java | 10 +
.../drill/exec/store/mongo/TestMongoQueries.java | 56 ++++
.../apache/drill/exec/planner/PlannerPhase.java | 1 +
.../exec/planner/common/DrillScanRelBase.java | 4 +-
.../logical/DrillPushProjectIntoScanRule.java | 26 +-
.../drill/exec/planner/logical/DrillScanRel.java | 4 +-
.../drill/exec/planner/physical/ScanPrel.java | 4 +-
.../exec/planner/physical/StreamAggPrule.java | 1 +
15 files changed, 468 insertions(+), 30 deletions(-)
diff --git a/contrib/storage-druid/src/main/java/org/apache/drill/exec/store/druid/DruidPushDownFilterForScan.java b/contrib/storage-druid/src/main/java/org/apache/drill/exec/store/druid/DruidPushDownFilterForScan.java
index 65d95aa..2c5fcee 100644
--- a/contrib/storage-druid/src/main/java/org/apache/drill/exec/store/druid/DruidPushDownFilterForScan.java
+++ b/contrib/storage-druid/src/main/java/org/apache/drill/exec/store/druid/DruidPushDownFilterForScan.java
@@ -73,7 +73,7 @@ public class DruidPushDownFilterForScan extends StoragePluginOptimizerRule {
groupScan.getMaxRecordsToRead());
newGroupsScan.setFilterPushedDown(true);
- ScanPrel newScanPrel = scan.copy(filter.getTraitSet(), newGroupsScan);
+ ScanPrel newScanPrel = scan.copy(filter.getTraitSet(), newGroupsScan, filter.getRowType());
if (druidFilterBuilder.isAllExpressionsConverted()) {
/*
* Since we could convert the entire filter condition expression into a
diff --git a/contrib/storage-mongo/src/main/java/org/apache/drill/exec/store/mongo/MongoGroupScan.java b/contrib/storage-mongo/src/main/java/org/apache/drill/exec/store/mongo/MongoGroupScan.java
index 8b57012..6662e8c 100644
--- a/contrib/storage-mongo/src/main/java/org/apache/drill/exec/store/mongo/MongoGroupScan.java
+++ b/contrib/storage-mongo/src/main/java/org/apache/drill/exec/store/mongo/MongoGroupScan.java
@@ -466,7 +466,8 @@ public class MongoGroupScan extends AbstractGroupScan implements
.setMinFilters(chunkInfo.getMinFilters())
.setMaxFilters(chunkInfo.getMaxFilters())
.setMaxRecords(maxRecords)
- .setFilter(scanSpec.getFilters());
+ .setFilter(scanSpec.getFilters())
+ .setAggregates(scanSpec.getAggregates());
}
@Override
diff --git a/contrib/storage-mongo/src/main/java/org/apache/drill/exec/store/mongo/MongoPushDownAggregateForScan.java b/contrib/storage-mongo/src/main/java/org/apache/drill/exec/store/mongo/MongoPushDownAggregateForScan.java
new file mode 100644
index 0000000..f7e1a00
--- /dev/null
+++ b/contrib/storage-mongo/src/main/java/org/apache/drill/exec/store/mongo/MongoPushDownAggregateForScan.java
@@ -0,0 +1,301 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.drill.exec.store.mongo;
+
+import org.apache.calcite.plan.RelOptRuleCall;
+import org.apache.calcite.plan.RelOptRuleOperand;
+import org.apache.calcite.rel.core.Aggregate;
+import org.apache.calcite.rel.core.AggregateCall;
+import org.apache.calcite.rel.core.Project;
+import org.apache.calcite.rel.type.RelDataType;
+import org.apache.calcite.sql.SqlAggFunction;
+import org.apache.calcite.sql.fun.SqlStdOperatorTable;
+import org.apache.calcite.sql.fun.SqlSumAggFunction;
+import org.apache.calcite.sql.fun.SqlSumEmptyIsZeroAggFunction;
+import org.apache.calcite.sql.validate.SqlValidatorUtil;
+import org.apache.calcite.util.Util;
+import org.apache.drill.common.exceptions.DrillRuntimeException;
+import org.apache.drill.common.expression.SchemaPath;
+import org.apache.drill.exec.planner.common.DrillScanRelBase;
+import org.apache.drill.exec.planner.logical.RelOptHelper;
+import org.apache.drill.exec.store.StoragePluginOptimizerRule;
+import org.bson.BsonDocument;
+import org.bson.conversions.Bson;
+
+import java.io.IOException;
+import java.util.AbstractList;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.stream.Collectors;
+
+/**
+ * Pushes an {@link Aggregate} that sits directly on a Mongo scan into the scan as a MongoDB
+ * aggregation pipeline: a {@code $group} stage, followed by a {@code $project} stage that
+ * renames the group keys when there are any. The translation follows Calcite's MongoDB adapter.
+ * <p>
+ * Only simple GROUP BY without DISTINCT or FILTER, using COUNT, SUM, $SUM0, MIN, MAX or AVG, is
+ * pushed down; any other aggregate is left for Drill to evaluate.
+ * <p>
+ * NOTE(review): {@code $group} emits no document for an empty input (SQL returns one row for an
+ * aggregate without GROUP BY), and {@code $sum} yields 0 where SQL SUM yields NULL for an
+ * all-null group; confirm whether these cases need special handling.
+ */
+public class MongoPushDownAggregateForScan extends StoragePluginOptimizerRule {
+
+  /** Matches an {@code Aggregate} directly over a scan. */
+  public static final StoragePluginOptimizerRule INSTANCE =
+      new MongoPushDownAggregateForScan(
+          RelOptHelper.some(Aggregate.class, RelOptHelper.any(DrillScanRelBase.class)),
+          "MongoPushDownAggregateForScan");
+
+  /**
+   * Matches {@code Aggregate} over {@code Project} over a scan. This shape is not translated yet,
+   * so {@link #matches} rejects it: the aggregate's field indexes refer to the project's output.
+   */
+  public static final StoragePluginOptimizerRule PROJ_INSTANCE =
+      new MongoPushDownAggregateForScan(
+          RelOptHelper.some(Aggregate.class,
+              RelOptHelper.some(Project.class, RelOptHelper.any(DrillScanRelBase.class))),
+          "MongoPushDownAggregateForScan_project");
+
+  /** Creates a rule for the given operand tree and description. */
+  public MongoPushDownAggregateForScan(RelOptRuleOperand operand, String desc) {
+    super(operand, desc);
+  }
+
+  /**
+   * Returns unique field names for {@code rowType} that are safe in MongoDB documents: a leading
+   * {@code $} (as in Calcite's {@code $f0}) is rewritten, since MongoDB treats it as an operator.
+   */
+  static List<String> mongoFieldNames(final RelDataType rowType) {
+    return SqlValidatorUtil.uniquify(
+        new AbstractList<String>() {
+          @Override public String get(int index) {
+            final String name = rowType.getFieldList().get(index).getName();
+            // "$f0" -> "_0"; Math.min keeps a bare "$" from throwing
+            return name.startsWith("$") ? "_" + name.substring(Math.min(2, name.length())) : name;
+          }
+
+          @Override public int size() {
+            return rowType.getFieldCount();
+          }
+        },
+        SqlValidatorUtil.EXPR_SUGGESTER, true);
+  }
+
+  /** Returns {@code s} unchanged when it is a valid unquoted JSON token, else {@link #quote}s it. */
+  static String maybeQuote(String s) {
+    return needsQuote(s) ? quote(s) : s;
+  }
+
+  /** Wraps {@code s} in single quotes, escaping backslashes and embedded single quotes. */
+  static String quote(String s) {
+    return "'" + s.replace("\\", "\\\\").replace("'", "\\'") + "'";
+  }
+
+  /** Whether {@code s} must be quoted: empty, non-identifier start, non-identifier char or '$'. */
+  private static boolean needsQuote(String s) {
+    if (s.isEmpty() || !Character.isJavaIdentifierStart(s.charAt(0))) {
+      return true;
+    }
+    for (int i = 0, n = s.length(); i < n; i++) {
+      char c = s.charAt(i);
+      if (!Character.isJavaIdentifierPart(c) || c == '$') {
+        return true;
+      }
+    }
+    return false;
+  }
+
+  /**
+   * Accepts only an aggregate directly over a Mongo scan that has no aggregation pipeline and
+   * no LIMIT pushed into it yet, and only when every aggregate call can be translated.
+   */
+  @Override
+  public boolean matches(RelOptRuleCall call) {
+    if (call.rels.length != 2) {
+      return false; // PROJ_INSTANCE shape, not supported yet
+    }
+    Aggregate aggregate = call.rel(0);
+    DrillScanRelBase scan = call.rel(1);
+    if (!(scan.getGroupScan() instanceof MongoGroupScan)
+        || aggregate.getGroupType() != Aggregate.Group.SIMPLE) {
+      return false;
+    }
+    MongoGroupScan groupScan = (MongoGroupScan) scan.getGroupScan();
+    List<Bson> pushed = groupScan.getScanSpec().getAggregates();
+    // A new $group would replace the already pushed pipeline (see aggregate()), and a pushed
+    // LIMIT must apply to the raw documents rather than to the aggregated ones.
+    if ((pushed != null && !pushed.isEmpty()) || groupScan.getMaxRecords() > 0) {
+      return false;
+    }
+    return aggregate.getAggCallList().stream()
+        .allMatch(MongoPushDownAggregateForScan::isSupported);
+  }
+
+  /**
+   * Replaces the aggregate with a copy of the scan whose spec carries the {@code $group} (and
+   * {@code $project}) stages and whose row type is the aggregate's row type.
+   */
+  @Override
+  public void onMatch(RelOptRuleCall call) {
+    Aggregate aggregate = call.rel(0);
+    DrillScanRelBase scan = call.rel(1);
+    MongoGroupScan groupScan = (MongoGroupScan) scan.getGroupScan();
+
+    List<String> inNames = mongoFieldNames(scan.getRowType());
+    List<String> outNames = mongoFieldNames(aggregate.getRowType());
+    List<String> operations = new ArrayList<>();
+    operations.add("{$group: "
+        + Util.toString(groupStage(aggregate, inNames, outNames), "{", ", ", "}") + "}");
+    if (!aggregate.getGroupSet().isEmpty()) {
+      operations.add("{$project: "
+          + Util.toString(projectStage(aggregate, inNames, outNames), "{", ", ", "}") + "}");
+    }
+
+    MongoScanSpec newScanSpec = aggregate(groupScan.getScanSpec(), operations);
+    // The rewritten scan produces exactly the aggregate's output columns
+    List<SchemaPath> columns = outNames.stream()
+        .map(SchemaPath::getSimplePath)
+        .collect(Collectors.toList());
+    try {
+      MongoGroupScan newGroupScan = new MongoGroupScan(groupScan.getUserName(),
+          groupScan.getStoragePlugin(), newScanSpec, columns, groupScan.getMaxRecords());
+      call.transformTo(
+          scan.copy(aggregate.getTraitSet(), newGroupScan, aggregate.getRowType()));
+    } catch (IOException e) {
+      throw new DrillRuntimeException(e.getMessage(), e);
+    }
+  }
+
+  /** Builds the {@code $group} body: the {@code _id} key, then one accumulator per call. */
+  private static List<String> groupStage(Aggregate aggregate, List<String> inNames,
+      List<String> outNames) {
+    List<String> group = new ArrayList<>();
+    if (aggregate.getGroupSet().cardinality() == 1) {
+      group.add("_id: " + maybeQuote("$" + inNames.get(aggregate.getGroupSet().nth(0))));
+    } else {
+      // Compound key (empty without GROUP BY), keyed by input field name
+      List<String> keys = new ArrayList<>();
+      for (int key : aggregate.getGroupSet()) {
+        String inName = inNames.get(key);
+        keys.add(maybeQuote(inName) + ": " + quote("$" + inName));
+      }
+      group.add("_id: " + Util.toString(keys, "{", ", ", "}"));
+    }
+    // In the aggregate's row type the accumulator outputs follow the group keys
+    int out = aggregate.getGroupSet().cardinality();
+    for (AggregateCall aggCall : aggregate.getAggCallList()) {
+      group.add(maybeQuote(outNames.get(out++)) + ": "
+          + toMongo(aggCall.getAggregation(), inNames, aggCall.getArgList()));
+    }
+    return group;
+  }
+
+  /** Builds the {@code $project} body that exposes the group keys under their output names. */
+  private static List<String> projectStage(Aggregate aggregate, List<String> inNames,
+      List<String> outNames) {
+    List<String> fields = new ArrayList<>();
+    int keyCount = aggregate.getGroupSet().cardinality();
+    if (keyCount == 1) {
+      fields.add(maybeQuote(outNames.get(0)) + ": " + maybeQuote("$_id"));
+    } else {
+      fields.add("_id: 0");
+      int pos = 0;
+      for (int key : aggregate.getGroupSet()) {
+        // Output names are positional; the _id sub-document is keyed by input field name
+        fields.add(maybeQuote(outNames.get(pos++)) + ": "
+            + maybeQuote("$_id." + inNames.get(key)));
+      }
+    }
+    for (int out = keyCount; out < outNames.size(); out++) {
+      fields.add(maybeQuote(outNames.get(out)) + ": " + maybeQuote("$" + outNames.get(out)));
+    }
+    return fields;
+  }
+
+  /** Whether {@link #toMongo} can translate the call: no DISTINCT/FILTER, known function, arity. */
+  private static boolean isSupported(AggregateCall aggCall) {
+    if (aggCall.isDistinct() || aggCall.hasFilter()) {
+      return false;
+    }
+    int argCount = aggCall.getArgList().size();
+    return isCount(aggCall.getAggregation())
+        ? argCount <= 1
+        : argCount == 1 && accumulatorFor(aggCall.getAggregation()) != null;
+  }
+
+  /** Whether the function is COUNT (compared by name rather than by class). */
+  private static boolean isCount(SqlAggFunction aggregation) {
+    return aggregation.getName().equals(SqlStdOperatorTable.COUNT.getName());
+  }
+
+  /** Returns the MongoDB accumulator for a one-argument function, or {@code null} if unknown. */
+  private static String accumulatorFor(SqlAggFunction aggregation) {
+    String name = aggregation.getName();
+    if (aggregation instanceof SqlSumAggFunction
+        || aggregation instanceof SqlSumEmptyIsZeroAggFunction
+        || name.equals(SqlStdOperatorTable.SUM.getName())
+        || name.equals(SqlStdOperatorTable.SUM0.getName())) {
+      return "$sum";
+    } else if (name.equals(SqlStdOperatorTable.MIN.getName())) {
+      return "$min";
+    } else if (name.equals(SqlStdOperatorTable.MAX.getName())) {
+      return "$max";
+    } else if (name.equals(SqlStdOperatorTable.AVG.getName())) {
+      return "$avg";
+    }
+    return null;
+  }
+
+  /**
+   * Translates one aggregate call into a {@code $group} accumulator, e.g. {@code {$sum: 1}}.
+   *
+   * @param aggregation the aggregate function; {@link #isSupported} must accept the call
+   * @param inNames Mongo field names of the scan's row type
+   * @param args indexes into {@code inNames} of the call's arguments
+   * @return the accumulator as relaxed JSON, later parsed with {@link BsonDocument#parse}
+   */
+  private static String toMongo(SqlAggFunction aggregation, List<String> inNames,
+      List<Integer> args) {
+    if (isCount(aggregation)) {
+      if (args.isEmpty()) {
+        return "{$sum: 1}"; // COUNT(*)
+      }
+      // COUNT(col) must skip nulls: the field is referenced as '$col' (a bare 'col' is a string
+      // literal that never equals null), and $ifNull turns a missing field into null as well.
+      return "{$sum: {$cond: [{$eq: [{$ifNull: [" + quote("$" + inNames.get(args.get(0)))
+          + ", null]}, null]}, 0, 1]}}";
+    }
+    String accumulator = accumulatorFor(aggregation);
+    if (accumulator == null || args.size() != 1) {
+      throw new AssertionError("unknown aggregate " + aggregation);
+    }
+    return "{" + accumulator + ": " + maybeQuote("$" + inNames.get(args.get(0))) + "}";
+  }
+
+  /** Returns a spec for the same collection and filters that also runs {@code operations}. */
+  private MongoScanSpec aggregate(MongoScanSpec scanSpec, final List<String> operations) {
+    final List<Bson> list = new ArrayList<>(operations.size());
+    for (String operation : operations) {
+      list.add(BsonDocument.parse(operation));
+    }
+    return new MongoScanSpec(scanSpec.getDbName(), scanSpec.getCollectionName(),
+        scanSpec.getFilters(), list);
+  }
+}
diff --git a/contrib/storage-mongo/src/main/java/org/apache/drill/exec/store/mongo/MongoPushDownFilterForScan.java b/contrib/storage-mongo/src/main/java/org/apache/drill/exec/store/mongo/MongoPushDownFilterForScan.java
index 5e57890..b1c06e7 100644
--- a/contrib/storage-mongo/src/main/java/org/apache/drill/exec/store/mongo/MongoPushDownFilterForScan.java
+++ b/contrib/storage-mongo/src/main/java/org/apache/drill/exec/store/mongo/MongoPushDownFilterForScan.java
@@ -76,7 +76,7 @@ public class MongoPushDownFilterForScan extends StoragePluginOptimizerRule {
}
newGroupsScan.setFilterPushedDown(true);
- RelNode newScanPrel = scan.copy(filter.getTraitSet(), newGroupsScan);
+ RelNode newScanPrel = scan.copy(filter.getTraitSet(), newGroupsScan, filter.getRowType());
if (mongoFilterBuilder.isAllExpressionsConverted()) {
/*
diff --git a/contrib/storage-mongo/src/main/java/org/apache/drill/exec/store/mongo/MongoRecordReader.java b/contrib/storage-mongo/src/main/java/org/apache/drill/exec/store/mongo/MongoRecordReader.java
index b06fe36..7c4f3f2 100644
--- a/contrib/storage-mongo/src/main/java/org/apache/drill/exec/store/mongo/MongoRecordReader.java
+++ b/contrib/storage-mongo/src/main/java/org/apache/drill/exec/store/mongo/MongoRecordReader.java
@@ -18,6 +18,7 @@
package org.apache.drill.exec.store.mongo;
import java.io.IOException;
+import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Map;
@@ -25,6 +26,9 @@ import java.util.Map.Entry;
import java.util.Set;
import java.util.concurrent.TimeUnit;
+import com.mongodb.client.FindIterable;
+import com.mongodb.client.model.Aggregates;
+import org.apache.commons.collections.CollectionUtils;
import org.apache.drill.common.exceptions.DrillRuntimeException;
import org.apache.drill.common.exceptions.ExecutionSetupException;
import org.apache.drill.common.expression.SchemaPath;
@@ -40,6 +44,7 @@ import org.apache.drill.exec.vector.complex.impl.VectorContainerWriter;
import org.bson.BsonDocument;
import org.bson.BsonDocumentReader;
import org.bson.Document;
+import org.bson.conversions.Bson;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -64,6 +69,7 @@ public class MongoRecordReader extends AbstractRecordReader {
private VectorContainerWriter writer;
private Document filters;
+ private List<Bson> aggregates;
private final Document fields;
private final FragmentContext fragmentContext;
@@ -87,6 +93,7 @@ public class MongoRecordReader extends AbstractRecordReader {
fragmentContext = context;
this.plugin = plugin;
filters = new Document();
+ aggregates = subScanSpec.aggregates;
Map<String, List<Document>> mergedFilters = MongoUtils.mergeFilters(
subScanSpec.getMinFilters(), subScanSpec.getMaxFilters());
@@ -176,12 +183,24 @@ public class MongoRecordReader extends AbstractRecordReader {
logger.debug("Filters Applied : " + filters);
logger.debug("Fields Selected :" + fields);
- // Add limit to Mongo query
- if (maxRecords > 0) {
- logger.debug("Limit applied: {}", maxRecords);
- cursor = collection.find(filters).projection(fields).limit(maxRecords).batchSize(100).iterator();
+ if (CollectionUtils.isNotEmpty(aggregates)) {
+ List<Bson> operations = new ArrayList<>();
+ operations.add(Aggregates.match(filters));
+ operations.addAll(aggregates);
+ operations.add(Aggregates.project(fields));
+ if (maxRecords > 0) {
+ operations.add(Aggregates.limit(maxRecords));
+ }
+ cursor = collection.aggregate(operations).batchSize(100).iterator();
} else {
- cursor = collection.find(filters).projection(fields).batchSize(100).iterator();
+ // Add limit to Mongo query
+ FindIterable<BsonDocument> projection = collection.find(filters).projection(fields);
+ if (maxRecords > 0) {
+ logger.debug("Limit applied: {}", maxRecords);
+ projection = projection.limit(maxRecords);
+ }
+
+ cursor = projection.batchSize(100).iterator();
}
}
diff --git a/contrib/storage-mongo/src/main/java/org/apache/drill/exec/store/mongo/MongoScanSpec.java b/contrib/storage-mongo/src/main/java/org/apache/drill/exec/store/mongo/MongoScanSpec.java
index 5c56fcc..7ec1210 100644
--- a/contrib/storage-mongo/src/main/java/org/apache/drill/exec/store/mongo/MongoScanSpec.java
+++ b/contrib/storage-mongo/src/main/java/org/apache/drill/exec/store/mongo/MongoScanSpec.java
@@ -21,13 +21,19 @@ import org.bson.Document;
import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty;
+import org.bson.conversions.Bson;
+
+import java.util.List;
+import java.util.StringJoiner;
public class MongoScanSpec {
- private String dbName;
- private String collectionName;
+ private final String dbName;
+ private final String collectionName;
private Document filters;
+ private List<Bson> aggregates;
+
@JsonCreator
public MongoScanSpec(@JsonProperty("dbName") String dbName,
@JsonProperty("collectionName") String collectionName) {
@@ -42,6 +48,14 @@ public class MongoScanSpec {
this.filters = filters;
}
+ public MongoScanSpec(String dbName, String collectionName,
+ Document filters, List<Bson> aggregates) {
+ this.dbName = dbName;
+ this.collectionName = collectionName;
+ this.filters = filters;
+ this.aggregates = aggregates;
+ }
+
public String getDbName() {
return dbName;
}
@@ -54,10 +68,17 @@ public class MongoScanSpec {
return filters;
}
+ public List<Bson> getAggregates() {
+ return aggregates;
+ }
+
@Override
public String toString() {
- return "MongoScanSpec [dbName=" + dbName + ", collectionName="
- + collectionName + ", filters=" + filters + "]";
+ return new StringJoiner(", ", MongoScanSpec.class.getSimpleName() + "[", "]")
+ .add("dbName='" + dbName + "'")
+ .add("collectionName='" + collectionName + "'")
+ .add("filters=" + filters)
+ .add("aggregates=" + aggregates)
+ .toString();
}
-
}
diff --git a/contrib/storage-mongo/src/main/java/org/apache/drill/exec/store/mongo/MongoStoragePlugin.java b/contrib/storage-mongo/src/main/java/org/apache/drill/exec/store/mongo/MongoStoragePlugin.java
index da55907..f6c3ac2 100644
--- a/contrib/storage-mongo/src/main/java/org/apache/drill/exec/store/mongo/MongoStoragePlugin.java
+++ b/contrib/storage-mongo/src/main/java/org/apache/drill/exec/store/mongo/MongoStoragePlugin.java
@@ -25,16 +25,16 @@ import com.mongodb.client.MongoClient;
import com.mongodb.MongoCredential;
import com.mongodb.ServerAddress;
import com.mongodb.client.MongoClients;
+import org.apache.calcite.plan.RelOptRule;
import org.apache.calcite.schema.SchemaPlus;
import org.apache.drill.common.JSONOptions;
import org.apache.drill.common.exceptions.DrillRuntimeException;
-import org.apache.drill.common.exceptions.ExecutionSetupException;
import org.apache.drill.exec.ops.OptimizerRulesContext;
import org.apache.drill.exec.physical.base.AbstractGroupScan;
+import org.apache.drill.exec.planner.PlannerPhase;
import org.apache.drill.exec.server.DrillbitContext;
import org.apache.drill.exec.store.AbstractStoragePlugin;
import org.apache.drill.exec.store.SchemaConfig;
-import org.apache.drill.exec.store.StoragePluginOptimizerRule;
import org.apache.drill.exec.store.mongo.schema.MongoSchemaFactory;
import org.apache.drill.common.logical.security.CredentialsProvider;
import org.apache.drill.exec.store.security.HadoopCredentialsProvider;
@@ -52,6 +52,7 @@ import org.slf4j.LoggerFactory;
import java.io.IOException;
import java.net.URLEncoder;
+import java.util.Collections;
import java.util.List;
import java.util.Set;
import java.util.concurrent.ExecutionException;
@@ -68,7 +69,7 @@ public class MongoStoragePlugin extends AbstractStoragePlugin {
public MongoStoragePlugin(
MongoStoragePluginConfig mongoConfig,
DrillbitContext context,
- String name) throws ExecutionSetupException {
+ String name) {
super(context, name);
this.mongoConfig = mongoConfig;
String connection = addCredentialsFromCredentialsProvider(this.mongoConfig.getConnection(), name);
@@ -120,7 +121,7 @@ public class MongoStoragePlugin extends AbstractStoragePlugin {
}
@Override
- public void registerSchemas(SchemaConfig schemaConfig, SchemaPlus parent) throws IOException {
+ public void registerSchemas(SchemaConfig schemaConfig, SchemaPlus parent) {
schemaFactory.registerSchemas(schemaConfig, parent);
}
@@ -137,8 +138,19 @@ public class MongoStoragePlugin extends AbstractStoragePlugin {
}
@Override
- public Set<StoragePluginOptimizerRule> getPhysicalOptimizerRules(OptimizerRulesContext optimizerRulesContext) {
- return ImmutableSet.of(MongoPushDownFilterForScan.INSTANCE);
+ public Set<? extends RelOptRule> getOptimizerRules(OptimizerRulesContext optimizerContext, PlannerPhase phase) {
+ switch (phase) {
+ case PHYSICAL:
+ case LOGICAL:
+ return ImmutableSet.of(MongoPushDownFilterForScan.INSTANCE,
+ MongoPushDownAggregateForScan.INSTANCE);
+ case LOGICAL_PRUNE_AND_JOIN:
+ case LOGICAL_PRUNE:
+ case PARTITION_PRUNING:
+ case JOIN_PLANNING:
+ default:
+ return Collections.emptySet();
+ }
}
diff --git a/contrib/storage-mongo/src/main/java/org/apache/drill/exec/store/mongo/MongoSubScan.java b/contrib/storage-mongo/src/main/java/org/apache/drill/exec/store/mongo/MongoSubScan.java
index af13eb5..a32336d 100644
--- a/contrib/storage-mongo/src/main/java/org/apache/drill/exec/store/mongo/MongoSubScan.java
+++ b/contrib/storage-mongo/src/main/java/org/apache/drill/exec/store/mongo/MongoSubScan.java
@@ -40,6 +40,7 @@ import com.fasterxml.jackson.annotation.JsonIgnore;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.annotation.JsonTypeName;
import org.apache.drill.shaded.guava.com.google.common.base.Preconditions;
+import org.bson.conversions.Bson;
@JsonTypeName("mongo-shard-read")
public class MongoSubScan extends AbstractBase implements SubScan {
@@ -132,6 +133,7 @@ public class MongoSubScan extends AbstractBase implements SubScan {
protected int maxRecords;
protected Document filter;
+ protected List<Bson> aggregates;
@JsonCreator
public MongoSubScanSpec(@JsonProperty("dbName") String dbName,
@@ -140,6 +142,7 @@ public class MongoSubScan extends AbstractBase implements SubScan {
@JsonProperty("minFilters") Map<String, Object> minFilters,
@JsonProperty("maxFilters") Map<String, Object> maxFilters,
@JsonProperty("filters") Document filters,
+ @JsonProperty("aggregates") List<Bson> aggregates,
@JsonProperty("maxRecords") int maxRecords) {
this.dbName = dbName;
this.collectionName = collectionName;
@@ -147,6 +150,7 @@ public class MongoSubScan extends AbstractBase implements SubScan {
this.minFilters = minFilters;
this.maxFilters = maxFilters;
this.filter = filters;
+ this.aggregates = aggregates;
this.maxRecords = maxRecords;
}
@@ -215,6 +219,11 @@ public class MongoSubScan extends AbstractBase implements SubScan {
return this;
}
+ public MongoSubScanSpec setAggregates(List<Bson> aggregates) {
+ this.aggregates = aggregates;
+ return this;
+ }
+
@Override
public String toString() {
return new PlanStringBuilder(this)
@@ -224,6 +233,7 @@ public class MongoSubScan extends AbstractBase implements SubScan {
.field("minFilters", minFilters)
.field("maxFilters", maxFilters)
.field("filter", filter)
+ .field("aggregates", aggregates)
.field("maxRecords", maxRecords)
.toString();
diff --git a/contrib/storage-mongo/src/test/java/org/apache/drill/exec/store/mongo/TestMongoQueries.java b/contrib/storage-mongo/src/test/java/org/apache/drill/exec/store/mongo/TestMongoQueries.java
index 4b20ebc..d5316f9 100644
--- a/contrib/storage-mongo/src/test/java/org/apache/drill/exec/store/mongo/TestMongoQueries.java
+++ b/contrib/storage-mongo/src/test/java/org/apache/drill/exec/store/mongo/TestMongoQueries.java
@@ -105,4 +105,60 @@ public class TestMongoQueries extends MongoTestBase {
.expectsNumRecords(5)
.go();
}
+
+  /** COUNT(column) without GROUP BY is evaluated by MongoDB, not by a Drill aggregate. */
+  @Test
+  public void testCountColumnPushDown() throws Exception {
+    String query = "select count(t.name) as c from mongo.%s.`%s` t";
+    queryBuilder().sql(query, DONUTS_DB, DONUTS_COLLECTION)
+        .planMatcher()
+        .exclude("Agg\\(")
+        .include("Scan\\(.*aggregates")
+        .match();
+
+    testBuilder()
+        .sqlQuery(query, DONUTS_DB, DONUTS_COLLECTION)
+        .unOrdered()
+        .baselineColumns("c")
+        .baselineValues(5)
+        .go();
+  }
+
+  /** COUNT with a single GROUP BY key is pushed down as well. */
+  @Test
+  public void testCountGroupByPushDown() throws Exception {
+    String query = "select count(t.id) as c, t.type from mongo.%s.`%s` t group by t.type";
+    queryBuilder().sql(query, DONUTS_DB, DONUTS_COLLECTION)
+        .planMatcher()
+        .exclude("Agg\\(")
+        .include("Scan\\(.*aggregates")
+        .match();
+
+    // Exactly one group ('donut') is expected
+    testBuilder()
+        .sqlQuery(query, DONUTS_DB, DONUTS_COLLECTION)
+        .unOrdered()
+        .baselineColumns("c", "type")
+        .baselineValues(5, "donut")
+        .go();
+  }
+
+  /** A WHERE clause is pushed down together with the aggregate. */
+  @Test
+  public void testCountColumnPushDownWithFilter() throws Exception {
+    String query = "select count(t.id) as c from mongo.%s.`%s` t where t.name = 'Cake'";
+    // Neither the aggregate nor the filter may remain in Drill's plan
+    queryBuilder().sql(query, DONUTS_DB, DONUTS_COLLECTION)
+        .planMatcher()
+        .exclude("Agg\\(", "Filter")
+        .include("Scan\\(.*aggregates")
+        .match();
+
+    testBuilder()
+        .sqlQuery(query, DONUTS_DB, DONUTS_COLLECTION)
+        .unOrdered()
+        .baselineColumns("c")
+        .baselineValues(1)
+        .go();
+  }
}
diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/PlannerPhase.java b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/PlannerPhase.java
index 97b34e1..cb0fd3c 100644
--- a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/PlannerPhase.java
+++ b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/PlannerPhase.java
@@ -347,6 +347,7 @@ public enum PlannerPhase {
// RuleInstance.PROJECT_SET_OP_TRANSPOSE_RULE,
RuleInstance.PROJECT_WINDOW_TRANSPOSE_RULE,
DrillPushProjectIntoScanRule.INSTANCE,
+ DrillPushProjectIntoScanRule.LOGICAL_INSTANCE,
DrillPushProjectIntoScanRule.DRILL_LOGICAL_INSTANCE,
/*
diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/common/DrillScanRelBase.java b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/common/DrillScanRelBase.java
index fe67709..a307f93 100644
--- a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/common/DrillScanRelBase.java
+++ b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/common/DrillScanRelBase.java
@@ -19,6 +19,8 @@ package org.apache.drill.exec.planner.common;
import java.io.IOException;
import java.util.List;
+
+import org.apache.calcite.rel.type.RelDataType;
import org.apache.drill.common.expression.SchemaPath;
import org.apache.drill.exec.physical.base.GroupScan;
import org.apache.drill.exec.planner.logical.DrillTable;
@@ -87,5 +89,5 @@ public abstract class DrillScanRelBase extends TableScan implements DrillRelNode
return planner.getCostFactory().makeCost(dRows, dCpu, dIo);
}
- public abstract DrillScanRelBase copy(RelTraitSet traitSet, GroupScan scan);
+ public abstract DrillScanRelBase copy(RelTraitSet traitSet, GroupScan scan, RelDataType rowType);
}
diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/logical/DrillPushProjectIntoScanRule.java b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/logical/DrillPushProjectIntoScanRule.java
index 91875bb..d54bb42 100644
--- a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/logical/DrillPushProjectIntoScanRule.java
+++ b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/logical/DrillPushProjectIntoScanRule.java
@@ -56,11 +56,16 @@ public class DrillPushProjectIntoScanRule extends RelOptRule {
}
};
- public static final RelOptRule DRILL_LOGICAL_INSTANCE =
+ public static final RelOptRule LOGICAL_INSTANCE =
new DrillPushProjectIntoScanRule(LogicalProject.class,
DrillScanRel.class,
"DrillPushProjectIntoScanRule:logical");
+ public static final RelOptRule DRILL_LOGICAL_INSTANCE =
+ new DrillPushProjectIntoScanRule(DrillProjectRel.class,
+ DrillScanRel.class,
+ "DrillPushProjectIntoScanRule:drill_logical");
+
public static final RelOptRule DRILL_PHYSICAL_INSTANCE =
new DrillPushProjectIntoScanRule(ProjectPrel.class,
ScanPrel.class,
@@ -167,11 +172,20 @@ public class DrillPushProjectIntoScanRule extends RelOptRule {
* @return new scan instance
*/
protected TableScan createScan(TableScan scan, ProjectPushInfo projectPushInfo) {
- return new DrillScanRel(scan.getCluster(),
- scan.getTraitSet().plus(DrillRel.DRILL_LOGICAL),
- scan.getTable(),
- projectPushInfo.createNewRowType(scan.getCluster().getTypeFactory()),
- projectPushInfo.getFields());
+ if (scan instanceof DrillScanRel) {
+ return new DrillScanRel(scan.getCluster(),
+ scan.getTraitSet().plus(DrillRel.DRILL_LOGICAL),
+ scan.getTable(),
+ ((DrillScanRel) scan).getGroupScan().clone(projectPushInfo.getFields()),
+ projectPushInfo.createNewRowType(scan.getCluster().getTypeFactory()),
+ projectPushInfo.getFields());
+ } else {
+ return new DrillScanRel(scan.getCluster(),
+ scan.getTraitSet().plus(DrillRel.DRILL_LOGICAL),
+ scan.getTable(),
+ projectPushInfo.createNewRowType(scan.getCluster().getTypeFactory()),
+ projectPushInfo.getFields());
+ }
}
/**
diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/logical/DrillScanRel.java b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/logical/DrillScanRel.java
index 26ef4ea..bcd9792 100644
--- a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/logical/DrillScanRel.java
+++ b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/logical/DrillScanRel.java
@@ -193,7 +193,7 @@ public class DrillScanRel extends DrillScanRelBase implements DrillRel {
}
@Override
- public DrillScanRel copy(RelTraitSet traitSet, GroupScan scan) {
- return new DrillScanRel(getCluster(), getTraitSet(), getTable(), scan, getRowType(), getColumns(), partitionFilterPushdown());
+ public DrillScanRel copy(RelTraitSet traitSet, GroupScan scan, RelDataType rowType) {
+ return new DrillScanRel(getCluster(), getTraitSet(), getTable(), scan, rowType, getColumns(), partitionFilterPushdown());
}
}
diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/physical/ScanPrel.java b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/physical/ScanPrel.java
index 50996b9..1e0bdf4 100644
--- a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/physical/ScanPrel.java
+++ b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/physical/ScanPrel.java
@@ -60,8 +60,8 @@ public class ScanPrel extends DrillScanRelBase implements LeafPrel, HasDistribut
}
@Override
- public ScanPrel copy(RelTraitSet traitSet, GroupScan scan) {
- return new ScanPrel(getCluster(), traitSet, scan, getRowType(), getTable());
+ public ScanPrel copy(RelTraitSet traitSet, GroupScan scan, RelDataType rowType) {
+ return new ScanPrel(getCluster(), traitSet, scan, rowType, getTable());
}
@Override
diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/physical/StreamAggPrule.java b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/physical/StreamAggPrule.java
index 0b68014..9fd789c 100644
--- a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/physical/StreamAggPrule.java
+++ b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/physical/StreamAggPrule.java
@@ -54,6 +54,7 @@ public class StreamAggPrule extends AggPruleBase {
@Override
public void onMatch(RelOptRuleCall call) {
+
final DrillAggregateRel aggregate = call.rel(0);
RelNode input = aggregate.getInput();
final RelCollation collation = getCollation(aggregate);