Posted to commits@drill.apache.org by vo...@apache.org on 2021/08/05 17:23:14 UTC

[drill] 01/13: Initial changes

This is an automated email from the ASF dual-hosted git repository.

volodymyr pushed a commit to branch mongo
in repository https://gitbox.apache.org/repos/asf/drill.git

commit 40d6f1e56ad64c665c69db73efd1c51fba9cd42d
Author: Volodymyr Vysotskyi <vv...@gmail.com>
AuthorDate: Thu Jul 8 21:31:37 2021 +0300

    Initial changes
---
 .../store/druid/DruidPushDownFilterForScan.java    |   2 +-
 .../drill/exec/store/mongo/MongoGroupScan.java     |   3 +-
 .../store/mongo/MongoPushDownAggregateForScan.java | 301 +++++++++++++++++++++
 .../store/mongo/MongoPushDownFilterForScan.java    |   2 +-
 .../drill/exec/store/mongo/MongoRecordReader.java  |  29 +-
 .../drill/exec/store/mongo/MongoScanSpec.java      |  31 ++-
 .../drill/exec/store/mongo/MongoStoragePlugin.java |  24 +-
 .../drill/exec/store/mongo/MongoSubScan.java       |  10 +
 .../drill/exec/store/mongo/TestMongoQueries.java   |  56 ++++
 .../apache/drill/exec/planner/PlannerPhase.java    |   1 +
 .../exec/planner/common/DrillScanRelBase.java      |   4 +-
 .../logical/DrillPushProjectIntoScanRule.java      |  26 +-
 .../drill/exec/planner/logical/DrillScanRel.java   |   4 +-
 .../drill/exec/planner/physical/ScanPrel.java      |   4 +-
 .../exec/planner/physical/StreamAggPrule.java      |   1 +
 15 files changed, 468 insertions(+), 30 deletions(-)

diff --git a/contrib/storage-druid/src/main/java/org/apache/drill/exec/store/druid/DruidPushDownFilterForScan.java b/contrib/storage-druid/src/main/java/org/apache/drill/exec/store/druid/DruidPushDownFilterForScan.java
index 65d95aa..2c5fcee 100644
--- a/contrib/storage-druid/src/main/java/org/apache/drill/exec/store/druid/DruidPushDownFilterForScan.java
+++ b/contrib/storage-druid/src/main/java/org/apache/drill/exec/store/druid/DruidPushDownFilterForScan.java
@@ -73,7 +73,7 @@ public class DruidPushDownFilterForScan extends StoragePluginOptimizerRule {
             groupScan.getMaxRecordsToRead());
     newGroupsScan.setFilterPushedDown(true);
 
-    ScanPrel newScanPrel = scan.copy(filter.getTraitSet(), newGroupsScan);
+    ScanPrel newScanPrel = scan.copy(filter.getTraitSet(), newGroupsScan, filter.getRowType());
     if (druidFilterBuilder.isAllExpressionsConverted()) {
       /*
        * Since we could convert the entire filter condition expression into a
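
The third argument to copy() is the thread that runs through this whole change
set: it carries a new output row type into the copied scan. Filter pushdown
keeps the scan's row type (the filter's type equals its input's), while the
aggregate pushdown below must replace it with the aggregate's output type. A
minimal sketch of the two call patterns, assuming the three-argument signature
added to DrillScanRelBase further down:

    // Filter pushdown: the row type is unchanged, so the filter's type passes through.
    RelNode filtered = scan.copy(filter.getTraitSet(), newGroupsScan, filter.getRowType());

    // Aggregate pushdown: the scan now produces the aggregate's output columns.
    RelNode aggregated = scan.copy(aggregate.getTraitSet(), newGroupScan, aggregate.getRowType());
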
diff --git a/contrib/storage-mongo/src/main/java/org/apache/drill/exec/store/mongo/MongoGroupScan.java b/contrib/storage-mongo/src/main/java/org/apache/drill/exec/store/mongo/MongoGroupScan.java
index 8b57012..6662e8c 100644
--- a/contrib/storage-mongo/src/main/java/org/apache/drill/exec/store/mongo/MongoGroupScan.java
+++ b/contrib/storage-mongo/src/main/java/org/apache/drill/exec/store/mongo/MongoGroupScan.java
@@ -466,7 +466,8 @@ public class MongoGroupScan extends AbstractGroupScan implements
         .setMinFilters(chunkInfo.getMinFilters())
         .setMaxFilters(chunkInfo.getMaxFilters())
         .setMaxRecords(maxRecords)
-        .setFilter(scanSpec.getFilters());
+        .setFilter(scanSpec.getFilters())
+        .setAggregates(scanSpec.getAggregates());
   }
 
   @Override
diff --git a/contrib/storage-mongo/src/main/java/org/apache/drill/exec/store/mongo/MongoPushDownAggregateForScan.java b/contrib/storage-mongo/src/main/java/org/apache/drill/exec/store/mongo/MongoPushDownAggregateForScan.java
new file mode 100644
index 0000000..f7e1a00
--- /dev/null
+++ b/contrib/storage-mongo/src/main/java/org/apache/drill/exec/store/mongo/MongoPushDownAggregateForScan.java
@@ -0,0 +1,301 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.drill.exec.store.mongo;
+
+import org.apache.calcite.avatica.util.DateTimeUtils;
+import org.apache.calcite.linq4j.function.Function1;
+import org.apache.calcite.linq4j.tree.Primitive;
+import org.apache.calcite.plan.RelOptRuleCall;
+import org.apache.calcite.plan.RelOptRuleOperand;
+import org.apache.calcite.rel.core.Aggregate;
+import org.apache.calcite.rel.core.AggregateCall;
+import org.apache.calcite.rel.core.Project;
+import org.apache.calcite.rel.type.RelDataType;
+import org.apache.calcite.sql.SqlAggFunction;
+import org.apache.calcite.sql.fun.SqlStdOperatorTable;
+import org.apache.calcite.sql.fun.SqlSumAggFunction;
+import org.apache.calcite.sql.fun.SqlSumEmptyIsZeroAggFunction;
+import org.apache.calcite.sql.validate.SqlValidatorUtil;
+import org.apache.calcite.util.Util;
+import org.apache.drill.common.exceptions.DrillRuntimeException;
+import org.apache.drill.common.expression.SchemaPath;
+import org.apache.drill.exec.planner.common.DrillScanRelBase;
+import org.apache.drill.exec.planner.logical.RelOptHelper;
+import org.apache.drill.exec.store.StoragePluginOptimizerRule;
+import org.bson.BsonDocument;
+import org.bson.Document;
+import org.bson.conversions.Bson;
+
+import java.io.IOException;
+import java.util.AbstractList;
+import java.util.ArrayList;
+import java.util.Date;
+import java.util.List;
+import java.util.Map;
+import java.util.stream.Collectors;
+
+public class MongoPushDownAggregateForScan extends StoragePluginOptimizerRule {
+  public static final StoragePluginOptimizerRule INSTANCE =
+      new MongoPushDownAggregateForScan(
+          RelOptHelper.some(Aggregate.class, RelOptHelper.any(DrillScanRelBase.class)),
+          "MongoPushDownAggregateForScan");
+  public static final StoragePluginOptimizerRule PROJ_INSTANCE =
+      new MongoPushDownAggregateForScan(
+          RelOptHelper.some(Aggregate.class,
+              RelOptHelper.some(Project.class, RelOptHelper.any(DrillScanRelBase.class))),
+          "MongoPushDownAggregateForScan_project");
+
+  public MongoPushDownAggregateForScan(RelOptRuleOperand operand, String desc) {
+    super(operand, desc);
+  }
+
+  static List<String> mongoFieldNames(final RelDataType rowType) {
+    return SqlValidatorUtil.uniquify(
+        new AbstractList<String>() {
+          @Override public String get(int index) {
+            final String name = rowType.getFieldList().get(index).getName();
+            return name.startsWith("$") ? "_" + name.substring(2) : name;
+          }
+
+          @Override public int size() {
+            return rowType.getFieldCount();
+          }
+        },
+        SqlValidatorUtil.EXPR_SUGGESTER, true);
+  }
+
+  static String maybeQuote(String s) {
+    if (!needsQuote(s)) {
+      return s;
+    }
+    return quote(s);
+  }
+
+  static String quote(String s) {
+    return "'" + s + "'"; // TODO: handle embedded quotes
+  }
+
+  private static boolean needsQuote(String s) {
+    for (int i = 0, n = s.length(); i < n; i++) {
+      char c = s.charAt(i);
+      if (!Character.isJavaIdentifierPart(c)
+          || c == '$') {
+        return true;
+      }
+    }
+    return false;
+  }
+
+  @Override
+  public void onMatch(RelOptRuleCall call) {
+    Aggregate aggregate = call.rel(0);
+    // The scan is the last matched rel: rel(1) for INSTANCE, rel(2) for
+    // PROJ_INSTANCE, which matches a Project between the aggregate and the scan.
+    DrillScanRelBase scan = call.rel(call.getRelList().size() - 1);
+
+    MongoGroupScan groupScan = (MongoGroupScan) scan.getGroupScan();
+
+    List<String> list = new ArrayList<>();
+    final List<String> inNames =
+        mongoFieldNames(scan.getRowType());
+    final List<String> outNames = mongoFieldNames(aggregate.getRowType());
+    int i = 0;
+    if (aggregate.getGroupSet().cardinality() == 1) {
+      final String inName = inNames.get(aggregate.getGroupSet().nth(0));
+      list.add("_id: " + maybeQuote("$" + inName));
+      ++i;
+    } else {
+      List<String> keys = new ArrayList<>();
+      for (int group : aggregate.getGroupSet()) {
+        final String inName = inNames.get(group);
+        keys.add(inName + ": " + quote("$" + inName));
+        ++i;
+      }
+      list.add("_id: " + Util.toString(keys, "{", ", ", "}"));
+    }
+    for (AggregateCall aggCall : aggregate.getAggCallList()) {
+      list.add(
+          maybeQuote(outNames.get(i++)) + ": "
+              + toMongo(aggCall.getAggregation(), inNames, aggCall.getArgList()));
+    }
+    List<String> aggsList = new ArrayList<>();
+    aggsList.add("{$group: " + Util.toString(list, "{", ", ", "}") + "}");
+    final List<String> fixups;
+    if (aggregate.getGroupSet().cardinality() == 1) {
+      fixups = new AbstractList<String>() {
+        @Override public String get(int index) {
+          final String outName = outNames.get(index);
+          return maybeQuote(outName) + ": "
+              + maybeQuote("$" + (index == 0 ? "_id" : outName));
+        }
+
+        @Override public int size() {
+          return outNames.size();
+        }
+      };
+    } else {
+      fixups = new ArrayList<>();
+      fixups.add("_id: 0");
+      i = 0;
+      for (int group : aggregate.getGroupSet()) {
+        fixups.add(
+            maybeQuote(outNames.get(group))
+                + ": "
+                + maybeQuote("$_id." + outNames.get(group)));
+        ++i;
+      }
+      for (AggregateCall ignored : aggregate.getAggCallList()) {
+        final String outName = outNames.get(i++);
+        fixups.add(
+            maybeQuote(outName) + ": " + maybeQuote(
+                "$" + outName));
+      }
+    }
+    if (!aggregate.getGroupSet().isEmpty()) {
+      aggsList.add("{$project: " + Util.toString(fixups, "{", ", ", "}") + "}");
+    }
+
+    MongoScanSpec mongoScanSpec = aggregate(groupScan.getScanSpec(), aggsList);
+    try {
+      List<SchemaPath> columns = outNames.stream()
+          .map(SchemaPath::getSimplePath)
+          .collect(Collectors.toList());
+      MongoGroupScan newGroupScan = new MongoGroupScan(groupScan.getUserName(), groupScan.getStoragePlugin(),
+          mongoScanSpec, columns, groupScan.getMaxRecords());
+      call.transformTo(scan.copy(aggregate.getTraitSet(), newGroupScan, aggregate.getRowType()));
+    } catch (IOException e) {
+      throw new DrillRuntimeException(e.getMessage(), e);
+    }
+  }
+
+  private static String toMongo(SqlAggFunction aggregation, List<String> inNames,
+      List<Integer> args) {
+    if (aggregation.getName().equals(SqlStdOperatorTable.COUNT.getName())) {
+      if (args.isEmpty()) {
+        return "{$sum: 1}";
+      } else {
+        assert args.size() == 1;
+        final String inName = inNames.get(args.get(0));
+        return "{$sum: {$cond: [ {$eq: ["
+            + quote(inName)
+            + ", null]}, 0, 1]}}";
+      }
+    } else if (aggregation instanceof SqlSumAggFunction
+        || aggregation instanceof SqlSumEmptyIsZeroAggFunction) {
+      assert args.size() == 1;
+      final String inName = inNames.get(args.get(0));
+      return "{$sum: " + maybeQuote("$" + inName) + "}";
+    } else if (aggregation.getName().equals(SqlStdOperatorTable.MIN.getName())) {
+      assert args.size() == 1;
+      final String inName = inNames.get(args.get(0));
+      return "{$min: " + maybeQuote("$" + inName) + "}";
+    } else if (aggregation.getName().equals(SqlStdOperatorTable.MAX.getName())) {
+      assert args.size() == 1;
+      final String inName = inNames.get(args.get(0));
+      return "{$max: " + maybeQuote("$" + inName) + "}";
+    } else if (aggregation.getName().equals(SqlStdOperatorTable.AVG.getName())) {
+      assert args.size() == 1;
+      final String inName = inNames.get(args.get(0));
+      return "{$avg: " + maybeQuote("$" + inName) + "}";
+    } else {
+      throw new AssertionError("unknown aggregate " + aggregation);
+    }
+  }
+
+  private MongoScanSpec aggregate(MongoScanSpec scanSpec,
+      final List<String> operations) {
+    final List<Bson> list = new ArrayList<>();
+    for (String operation : operations) {
+      list.add(BsonDocument.parse(operation));
+    }
+    return new MongoScanSpec(scanSpec.getDbName(), scanSpec.getCollectionName(),
+        scanSpec.getFilters(), list);
+  }
+
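+  // The helpers below (mapGetter through convert) look adapted from Calcite's
+  // MongoDB adapter; nothing in this rule calls them yet.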
+  static Function1<Document, Map> mapGetter() {
+    return a0 -> (Map) a0;
+  }
+
+  /** Returns a function that projects a single field. */
+  static Function1<Document, Object> singletonGetter(final String fieldName,
+      final Class fieldClass) {
+    return a0 -> convert(a0.get(fieldName), fieldClass);
+  }
+
+  /** Returns a function that projects fields.
+   *
+   * @param fields List of fields to project; or null to return map
+   */
+  static Function1<Document, Object[]> listGetter(
+      final List<Map.Entry<String, Class>> fields) {
+    return a0 -> {
+      Object[] objects = new Object[fields.size()];
+      for (int i = 0; i < fields.size(); i++) {
+        final Map.Entry<String, Class> field = fields.get(i);
+        final String name = field.getKey();
+        objects[i] = convert(a0.get(name), field.getValue());
+      }
+      return objects;
+    };
+  }
+
+  static Function1<Document, Object> getter(
+      List<Map.Entry<String, Class>> fields) {
+    //noinspection unchecked
+    return fields == null
+        ? (Function1) mapGetter()
+        : fields.size() == 1
+        ? singletonGetter(fields.get(0).getKey(), fields.get(0).getValue())
+        : (Function1) listGetter(fields);
+  }
+
+  @SuppressWarnings("JavaUtilDate")
+  private static Object convert(Object o, Class clazz) {
+    if (o == null) {
+      return null;
+    }
+    Primitive primitive = Primitive.of(clazz);
+    if (primitive != null) {
+      clazz = primitive.boxClass;
+    } else {
+      primitive = Primitive.ofBox(clazz);
+    }
+    if (clazz.isInstance(o)) {
+      return o;
+    }
+    if (o instanceof Date && primitive != null) {
+      o = ((Date) o).getTime() / DateTimeUtils.MILLIS_PER_DAY;
+    }
+    if (o instanceof Number && primitive != null) {
+      return primitive.number((Number) o);
+    }
+    return o;
+  }
+
+  // MongoDB $group accumulator operators, for reference: $addToSet, $avg,
+  // $first, $last, $max, $min, $mergeObjects, $push, $stdDevPop, $stdDevSamp, $sum.
+}
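
For orientation, here is what the rule emits for a simple grouped count. The
query and column names are illustrative, assuming the scan exposes columns id
and type:

    -- hypothetical query
    SELECT t.`type`, COUNT(t.id) AS c FROM mongo.donuts.`donuts` t GROUP BY t.`type`

    // stages appended to MongoScanSpec.aggregates, as built by onMatch/toMongo
    {$group: {_id: '$type', c: {$sum: {$cond: [ {$eq: ['id', null]}, 0, 1]}}}}
    {$project: {type: '$_id', c: '$c'}}

The trailing $project is the "fixup" stage: it renames _id back to the group
key so the scan's output columns line up with the aggregate's row type.
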
diff --git a/contrib/storage-mongo/src/main/java/org/apache/drill/exec/store/mongo/MongoPushDownFilterForScan.java b/contrib/storage-mongo/src/main/java/org/apache/drill/exec/store/mongo/MongoPushDownFilterForScan.java
index 5e57890..b1c06e7 100644
--- a/contrib/storage-mongo/src/main/java/org/apache/drill/exec/store/mongo/MongoPushDownFilterForScan.java
+++ b/contrib/storage-mongo/src/main/java/org/apache/drill/exec/store/mongo/MongoPushDownFilterForScan.java
@@ -76,7 +76,7 @@ public class MongoPushDownFilterForScan extends StoragePluginOptimizerRule {
     }
     newGroupsScan.setFilterPushedDown(true);
 
-    RelNode newScanPrel = scan.copy(filter.getTraitSet(), newGroupsScan);
+    RelNode newScanPrel = scan.copy(filter.getTraitSet(), newGroupsScan, filter.getRowType());
 
     if (mongoFilterBuilder.isAllExpressionsConverted()) {
       /*
diff --git a/contrib/storage-mongo/src/main/java/org/apache/drill/exec/store/mongo/MongoRecordReader.java b/contrib/storage-mongo/src/main/java/org/apache/drill/exec/store/mongo/MongoRecordReader.java
index b06fe36..7c4f3f2 100644
--- a/contrib/storage-mongo/src/main/java/org/apache/drill/exec/store/mongo/MongoRecordReader.java
+++ b/contrib/storage-mongo/src/main/java/org/apache/drill/exec/store/mongo/MongoRecordReader.java
@@ -18,6 +18,7 @@
 package org.apache.drill.exec.store.mongo;
 
 import java.io.IOException;
+import java.util.ArrayList;
 import java.util.Collection;
 import java.util.List;
 import java.util.Map;
@@ -25,6 +26,9 @@ import java.util.Map.Entry;
 import java.util.Set;
 import java.util.concurrent.TimeUnit;
 
+import com.mongodb.client.FindIterable;
+import com.mongodb.client.model.Aggregates;
+import org.apache.commons.collections.CollectionUtils;
 import org.apache.drill.common.exceptions.DrillRuntimeException;
 import org.apache.drill.common.exceptions.ExecutionSetupException;
 import org.apache.drill.common.expression.SchemaPath;
@@ -40,6 +44,7 @@ import org.apache.drill.exec.vector.complex.impl.VectorContainerWriter;
 import org.bson.BsonDocument;
 import org.bson.BsonDocumentReader;
 import org.bson.Document;
+import org.bson.conversions.Bson;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -64,6 +69,7 @@ public class MongoRecordReader extends AbstractRecordReader {
   private VectorContainerWriter writer;
 
   private Document filters;
+  private List<Bson> aggregates;
   private final Document fields;
 
   private final FragmentContext fragmentContext;
@@ -87,6 +93,7 @@ public class MongoRecordReader extends AbstractRecordReader {
     fragmentContext = context;
     this.plugin = plugin;
     filters = new Document();
+    aggregates = subScanSpec.aggregates;
     Map<String, List<Document>> mergedFilters = MongoUtils.mergeFilters(
         subScanSpec.getMinFilters(), subScanSpec.getMaxFilters());
 
@@ -176,12 +183,24 @@ public class MongoRecordReader extends AbstractRecordReader {
       logger.debug("Filters Applied : " + filters);
       logger.debug("Fields Selected :" + fields);
 
-      // Add limit to Mongo query
-      if (maxRecords > 0) {
-        logger.debug("Limit applied: {}", maxRecords);
-        cursor = collection.find(filters).projection(fields).limit(maxRecords).batchSize(100).iterator();
+      if (CollectionUtils.isNotEmpty(aggregates)) {
+        List<Bson> operations = new ArrayList<>();
+        operations.add(Aggregates.match(filters));
+        operations.addAll(aggregates);
+        operations.add(Aggregates.project(fields));
+        if (maxRecords > 0) {
+          operations.add(Aggregates.limit(maxRecords));
+        }
+        cursor = collection.aggregate(operations).batchSize(100).iterator();
       } else {
-        cursor = collection.find(filters).projection(fields).batchSize(100).iterator();
+        // Add limit to Mongo query
+        FindIterable<BsonDocument> projection = collection.find(filters).projection(fields);
+        if (maxRecords > 0) {
+          logger.debug("Limit applied: {}", maxRecords);
+          projection = projection.limit(maxRecords);
+        }
+
+        cursor = projection.batchSize(100).iterator();
       }
     }
 
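With pushed aggregates the reader now issues aggregate() instead of find(); in
mongo-shell terms the pipeline it builds is roughly (a sketch, placeholders in
angle brackets):

    db.<collection>.aggregate([
      {$match: <filters>},             // pushed-down filter, applied first
      <pushed $group/$project stages>, // from MongoPushDownAggregateForScan
      {$project: <fields>},            // column projection
      {$limit: <maxRecords>}           // only when a limit was pushed down
    ])
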
diff --git a/contrib/storage-mongo/src/main/java/org/apache/drill/exec/store/mongo/MongoScanSpec.java b/contrib/storage-mongo/src/main/java/org/apache/drill/exec/store/mongo/MongoScanSpec.java
index 5c56fcc..7ec1210 100644
--- a/contrib/storage-mongo/src/main/java/org/apache/drill/exec/store/mongo/MongoScanSpec.java
+++ b/contrib/storage-mongo/src/main/java/org/apache/drill/exec/store/mongo/MongoScanSpec.java
@@ -21,13 +21,19 @@ import org.bson.Document;
 
 import com.fasterxml.jackson.annotation.JsonCreator;
 import com.fasterxml.jackson.annotation.JsonProperty;
+import org.bson.conversions.Bson;
+
+import java.util.List;
+import java.util.StringJoiner;
 
 public class MongoScanSpec {
-  private String dbName;
-  private String collectionName;
+  private final String dbName;
+  private final String collectionName;
 
   private Document filters;
 
+  private List<Bson> aggregates;
+
   @JsonCreator
   public MongoScanSpec(@JsonProperty("dbName") String dbName,
       @JsonProperty("collectionName") String collectionName) {
@@ -42,6 +48,14 @@ public class MongoScanSpec {
     this.filters = filters;
   }
 
+  public MongoScanSpec(String dbName, String collectionName,
+      Document filters, List<Bson> aggregates) {
+    this.dbName = dbName;
+    this.collectionName = collectionName;
+    this.filters = filters;
+    this.aggregates = aggregates;
+  }
+
   public String getDbName() {
     return dbName;
   }
@@ -54,10 +68,17 @@ public class MongoScanSpec {
     return filters;
   }
 
+  public List<Bson> getAggregates() {
+    return aggregates;
+  }
+
   @Override
   public String toString() {
-    return "MongoScanSpec [dbName=" + dbName + ", collectionName="
-        + collectionName + ", filters=" + filters + "]";
+    return new StringJoiner(", ", MongoScanSpec.class.getSimpleName() + "[", "]")
+        .add("dbName='" + dbName + "'")
+        .add("collectionName='" + collectionName + "'")
+        .add("filters=" + filters)
+        .add("aggregates=" + aggregates)
+        .toString();
   }
-
 }
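
A minimal sketch of how the rule populates the new field, using the
four-argument constructor above; the database and collection names are
placeholders:

    List<Bson> stages = Collections.singletonList(
        BsonDocument.parse("{$group: {_id: '$type', c: {$sum: 1}}}"));
    MongoScanSpec spec = new MongoScanSpec("testDb", "testCollection", new Document(), stages);

Since Bson is an interface, round-tripping the plan through JSON may need
custom Jackson handling that this commit does not yet add.
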
diff --git a/contrib/storage-mongo/src/main/java/org/apache/drill/exec/store/mongo/MongoStoragePlugin.java b/contrib/storage-mongo/src/main/java/org/apache/drill/exec/store/mongo/MongoStoragePlugin.java
index da55907..f6c3ac2 100644
--- a/contrib/storage-mongo/src/main/java/org/apache/drill/exec/store/mongo/MongoStoragePlugin.java
+++ b/contrib/storage-mongo/src/main/java/org/apache/drill/exec/store/mongo/MongoStoragePlugin.java
@@ -25,16 +25,16 @@ import com.mongodb.client.MongoClient;
 import com.mongodb.MongoCredential;
 import com.mongodb.ServerAddress;
 import com.mongodb.client.MongoClients;
+import org.apache.calcite.plan.RelOptRule;
 import org.apache.calcite.schema.SchemaPlus;
 import org.apache.drill.common.JSONOptions;
 import org.apache.drill.common.exceptions.DrillRuntimeException;
-import org.apache.drill.common.exceptions.ExecutionSetupException;
 import org.apache.drill.exec.ops.OptimizerRulesContext;
 import org.apache.drill.exec.physical.base.AbstractGroupScan;
+import org.apache.drill.exec.planner.PlannerPhase;
 import org.apache.drill.exec.server.DrillbitContext;
 import org.apache.drill.exec.store.AbstractStoragePlugin;
 import org.apache.drill.exec.store.SchemaConfig;
-import org.apache.drill.exec.store.StoragePluginOptimizerRule;
 import org.apache.drill.exec.store.mongo.schema.MongoSchemaFactory;
 import org.apache.drill.common.logical.security.CredentialsProvider;
 import org.apache.drill.exec.store.security.HadoopCredentialsProvider;
@@ -52,6 +52,7 @@ import org.slf4j.LoggerFactory;
 
 import java.io.IOException;
 import java.net.URLEncoder;
+import java.util.Collections;
 import java.util.List;
 import java.util.Set;
 import java.util.concurrent.ExecutionException;
@@ -68,7 +69,7 @@ public class MongoStoragePlugin extends AbstractStoragePlugin {
   public MongoStoragePlugin(
       MongoStoragePluginConfig mongoConfig,
       DrillbitContext context,
-      String name) throws ExecutionSetupException {
+      String name) {
     super(context, name);
     this.mongoConfig = mongoConfig;
     String connection = addCredentialsFromCredentialsProvider(this.mongoConfig.getConnection(), name);
@@ -120,7 +121,7 @@ public class MongoStoragePlugin extends AbstractStoragePlugin {
   }
 
   @Override
-  public void registerSchemas(SchemaConfig schemaConfig, SchemaPlus parent) throws IOException {
+  public void registerSchemas(SchemaConfig schemaConfig, SchemaPlus parent) {
     schemaFactory.registerSchemas(schemaConfig, parent);
   }
 
@@ -137,8 +138,19 @@ public class MongoStoragePlugin extends AbstractStoragePlugin {
   }
 
   @Override
-  public Set<StoragePluginOptimizerRule> getPhysicalOptimizerRules(OptimizerRulesContext optimizerRulesContext) {
-    return ImmutableSet.of(MongoPushDownFilterForScan.INSTANCE);
+  public Set<? extends RelOptRule> getOptimizerRules(OptimizerRulesContext optimizerContext, PlannerPhase phase) {
+    switch (phase) {
+      case PHYSICAL:
+      case LOGICAL:
+        return ImmutableSet.of(MongoPushDownFilterForScan.INSTANCE,
+            MongoPushDownAggregateForScan.INSTANCE);
+      case LOGICAL_PRUNE_AND_JOIN:
+      case LOGICAL_PRUNE:
+      case PARTITION_PRUNING:
+      case JOIN_PLANNING:
+      default:
+        return Collections.emptySet();
+    }
   }
 
 
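The phase-based override replaces getPhysicalOptimizerRules() and registers
both pushdown rules for the LOGICAL and PHYSICAL phases; every other phase
contributes nothing. Reading the switch above:

    plugin.getOptimizerRules(ctx, PlannerPhase.LOGICAL);            // filter + aggregate pushdown
    plugin.getOptimizerRules(ctx, PlannerPhase.PHYSICAL);           // the same two rules
    plugin.getOptimizerRules(ctx, PlannerPhase.PARTITION_PRUNING);  // empty set
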
diff --git a/contrib/storage-mongo/src/main/java/org/apache/drill/exec/store/mongo/MongoSubScan.java b/contrib/storage-mongo/src/main/java/org/apache/drill/exec/store/mongo/MongoSubScan.java
index af13eb5..a32336d 100644
--- a/contrib/storage-mongo/src/main/java/org/apache/drill/exec/store/mongo/MongoSubScan.java
+++ b/contrib/storage-mongo/src/main/java/org/apache/drill/exec/store/mongo/MongoSubScan.java
@@ -40,6 +40,7 @@ import com.fasterxml.jackson.annotation.JsonIgnore;
 import com.fasterxml.jackson.annotation.JsonProperty;
 import com.fasterxml.jackson.annotation.JsonTypeName;
 import org.apache.drill.shaded.guava.com.google.common.base.Preconditions;
+import org.bson.conversions.Bson;
 
 @JsonTypeName("mongo-shard-read")
 public class MongoSubScan extends AbstractBase implements SubScan {
@@ -132,6 +133,7 @@ public class MongoSubScan extends AbstractBase implements SubScan {
     protected int maxRecords;
 
     protected Document filter;
+    protected List<Bson> aggregates;
 
     @JsonCreator
     public MongoSubScanSpec(@JsonProperty("dbName") String dbName,
@@ -140,6 +142,7 @@ public class MongoSubScan extends AbstractBase implements SubScan {
         @JsonProperty("minFilters") Map<String, Object> minFilters,
         @JsonProperty("maxFilters") Map<String, Object> maxFilters,
         @JsonProperty("filters") Document filters,
+        @JsonProperty("aggregates") List<Bson> aggregates,
         @JsonProperty("maxRecords") int maxRecords) {
       this.dbName = dbName;
       this.collectionName = collectionName;
@@ -147,6 +150,7 @@ public class MongoSubScan extends AbstractBase implements SubScan {
       this.minFilters = minFilters;
       this.maxFilters = maxFilters;
       this.filter = filters;
+      this.aggregates = aggregates;
       this.maxRecords = maxRecords;
     }
 
@@ -215,6 +219,11 @@ public class MongoSubScan extends AbstractBase implements SubScan {
       return this;
     }
 
+    public MongoSubScanSpec setAggregates(List<Bson> aggregates) {
+      this.aggregates = aggregates;
+      return this;
+    }
+
     @Override
     public String toString() {
       return new PlanStringBuilder(this)
@@ -224,6 +233,7 @@ public class MongoSubScan extends AbstractBase implements SubScan {
         .field("minFilters", minFilters)
         .field("maxFilters", maxFilters)
         .field("filter", filter)
+        .field("aggregates", aggregates)
         .field("maxRecords", maxRecords)
         .toString();
 
diff --git a/contrib/storage-mongo/src/test/java/org/apache/drill/exec/store/mongo/TestMongoQueries.java b/contrib/storage-mongo/src/test/java/org/apache/drill/exec/store/mongo/TestMongoQueries.java
index 4b20ebc..d5316f9 100644
--- a/contrib/storage-mongo/src/test/java/org/apache/drill/exec/store/mongo/TestMongoQueries.java
+++ b/contrib/storage-mongo/src/test/java/org/apache/drill/exec/store/mongo/TestMongoQueries.java
@@ -105,4 +105,60 @@ public class TestMongoQueries extends MongoTestBase {
         .expectsNumRecords(5)
         .go();
   }
+
+  @Test
+  public void testCountColumnPushDown() throws Exception {
+    String query = "select count(t.name) as c from mongo.%s.`%s` t";
+
+    queryBuilder().sql(query, DONUTS_DB, DONUTS_COLLECTION)
+        .planMatcher()
+        .exclude("Agg\\(")
+        .include("Scan\\(.*aggregates")
+        .match();
+
+    testBuilder()
+        .sqlQuery(query, DONUTS_DB, DONUTS_COLLECTION)
+        .unOrdered()
+        .baselineColumns("c")
+        .baselineValues(5L)
+        .go();
+  }
+
+  @Test
+  public void testCountGroupByPushDown() throws Exception {
+    String query = "select count(t.id) as c, t.type from mongo.%s.`%s` t group by t.type";
+
+    queryBuilder().sql(query, DONUTS_DB, DONUTS_COLLECTION)
+        .planMatcher()
+        .exclude("Agg\\(")
+        .include("Scan\\(.*aggregates")
+        .match();
+
+    testBuilder()
+        .sqlQuery(query, DONUTS_DB, DONUTS_COLLECTION)
+        .unOrdered()
+        .baselineColumns("c", "type")
+        .baselineValues(5L, "donut")
+        .go();
+  }
+
+  @Test
+  public void testCountColumnPushDownWithFilter() throws Exception {
+    String query = "select count(t.id) as c from mongo.%s.`%s` t where t.name = 'Cake'";
+
+    queryBuilder().sql(query, DONUTS_DB, DONUTS_COLLECTION)
+        .planMatcher()
+        .exclude("Agg\\(", "Filter")
+        .include("Scan\\(.*aggregates")
+        .match();
+
+    testBuilder()
+        .sqlQuery(query, DONUTS_DB, DONUTS_COLLECTION)
+        .unOrdered()
+        .baselineColumns("c")
+        .baselineValues(1L)
+        .go();
+  }
 }
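
The plan matchers in these tests assert that no Agg( operator survives above
the scan and that the scan itself carries the pipeline. The matched EXPLAIN
fragment would look roughly like this (shape illustrative, not verbatim):

    Scan(table=[[mongo, donuts, donuts]], groupscan=[MongoGroupScan
        [MongoScanSpec[dbName='donuts', collectionName='donuts', filters=...,
         aggregates=[{"$group": {...}}, {"$project": {...}}]]]])
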
diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/PlannerPhase.java b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/PlannerPhase.java
index 97b34e1..cb0fd3c 100644
--- a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/PlannerPhase.java
+++ b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/PlannerPhase.java
@@ -347,6 +347,7 @@ public enum PlannerPhase {
       // RuleInstance.PROJECT_SET_OP_TRANSPOSE_RULE,
       RuleInstance.PROJECT_WINDOW_TRANSPOSE_RULE,
       DrillPushProjectIntoScanRule.INSTANCE,
+      DrillPushProjectIntoScanRule.LOGICAL_INSTANCE,
       DrillPushProjectIntoScanRule.DRILL_LOGICAL_INSTANCE,
 
       /*
diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/common/DrillScanRelBase.java b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/common/DrillScanRelBase.java
index fe67709..a307f93 100644
--- a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/common/DrillScanRelBase.java
+++ b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/common/DrillScanRelBase.java
@@ -19,6 +19,8 @@ package org.apache.drill.exec.planner.common;
 
 import java.io.IOException;
 import java.util.List;
+
+import org.apache.calcite.rel.type.RelDataType;
 import org.apache.drill.common.expression.SchemaPath;
 import org.apache.drill.exec.physical.base.GroupScan;
 import org.apache.drill.exec.planner.logical.DrillTable;
@@ -87,5 +89,5 @@ public abstract class DrillScanRelBase extends TableScan implements DrillRelNode
     return planner.getCostFactory().makeCost(dRows, dCpu, dIo);
   }
 
-  public abstract DrillScanRelBase copy(RelTraitSet traitSet, GroupScan scan);
+  public abstract DrillScanRelBase copy(RelTraitSet traitSet, GroupScan scan, RelDataType rowType);
 }
diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/logical/DrillPushProjectIntoScanRule.java b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/logical/DrillPushProjectIntoScanRule.java
index 91875bb..d54bb42 100644
--- a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/logical/DrillPushProjectIntoScanRule.java
+++ b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/logical/DrillPushProjectIntoScanRule.java
@@ -56,11 +56,16 @@ public class DrillPushProjectIntoScanRule extends RelOptRule {
         }
       };
 
-  public static final RelOptRule DRILL_LOGICAL_INSTANCE =
+  public static final RelOptRule LOGICAL_INSTANCE =
       new DrillPushProjectIntoScanRule(LogicalProject.class,
           DrillScanRel.class,
           "DrillPushProjectIntoScanRule:logical");
 
+  public static final RelOptRule DRILL_LOGICAL_INSTANCE =
+      new DrillPushProjectIntoScanRule(DrillProjectRel.class,
+          DrillScanRel.class,
+          "DrillPushProjectIntoScanRule:drill_logical");
+
   public static final RelOptRule DRILL_PHYSICAL_INSTANCE =
       new DrillPushProjectIntoScanRule(ProjectPrel.class,
           ScanPrel.class,
@@ -167,11 +172,20 @@ public class DrillPushProjectIntoScanRule extends RelOptRule {
    * @return new scan instance
    */
   protected TableScan createScan(TableScan scan, ProjectPushInfo projectPushInfo) {
-    return new DrillScanRel(scan.getCluster(),
-        scan.getTraitSet().plus(DrillRel.DRILL_LOGICAL),
-        scan.getTable(),
-        projectPushInfo.createNewRowType(scan.getCluster().getTypeFactory()),
-        projectPushInfo.getFields());
+    if (scan instanceof DrillScanRel) {
+      return new DrillScanRel(scan.getCluster(),
+          scan.getTraitSet().plus(DrillRel.DRILL_LOGICAL),
+          scan.getTable(),
+          ((DrillScanRel) scan).getGroupScan().clone(projectPushInfo.getFields()),
+          projectPushInfo.createNewRowType(scan.getCluster().getTypeFactory()),
+          projectPushInfo.getFields());
+    } else {
+      return new DrillScanRel(scan.getCluster(),
+          scan.getTraitSet().plus(DrillRel.DRILL_LOGICAL),
+          scan.getTable(),
+          projectPushInfo.createNewRowType(scan.getCluster().getTypeFactory()),
+          projectPushInfo.getFields());
+    }
   }
 
   /**
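
createScan now distinguishes an input that is already a DrillScanRel: cloning
its existing group scan with the projected fields preserves anything previously
pushed into it (filters, aggregates) instead of rebuilding the scan from the
table alone. The key line, assuming GroupScan.clone(columns) keeps the rest of
the scan spec intact:

    GroupScan narrowed = ((DrillScanRel) scan).getGroupScan().clone(projectPushInfo.getFields());
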
diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/logical/DrillScanRel.java b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/logical/DrillScanRel.java
index 26ef4ea..bcd9792 100644
--- a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/logical/DrillScanRel.java
+++ b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/logical/DrillScanRel.java
@@ -193,7 +193,7 @@ public class DrillScanRel extends DrillScanRelBase implements DrillRel {
   }
 
   @Override
-  public DrillScanRel copy(RelTraitSet traitSet, GroupScan scan) {
-    return new DrillScanRel(getCluster(), getTraitSet(), getTable(), scan, getRowType(), getColumns(), partitionFilterPushdown());
+  public DrillScanRel copy(RelTraitSet traitSet, GroupScan scan, RelDataType rowType) {
+    return new DrillScanRel(getCluster(), traitSet, getTable(), scan, rowType, getColumns(), partitionFilterPushdown());
   }
 }
diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/physical/ScanPrel.java b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/physical/ScanPrel.java
index 50996b9..1e0bdf4 100644
--- a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/physical/ScanPrel.java
+++ b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/physical/ScanPrel.java
@@ -60,8 +60,8 @@ public class ScanPrel extends DrillScanRelBase implements LeafPrel, HasDistribut
   }
 
   @Override
-  public ScanPrel copy(RelTraitSet traitSet, GroupScan scan) {
-    return new ScanPrel(getCluster(), traitSet, scan, getRowType(), getTable());
+  public ScanPrel copy(RelTraitSet traitSet, GroupScan scan, RelDataType rowType) {
+    return new ScanPrel(getCluster(), traitSet, scan, rowType, getTable());
   }
 
   @Override
diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/physical/StreamAggPrule.java b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/physical/StreamAggPrule.java
index 0b68014..9fd789c 100644
--- a/exec/java-exec/src/main/java/org/apache/drill/exec/planner/physical/StreamAggPrule.java
+++ b/exec/java-exec/src/main/java/org/apache/drill/exec/planner/physical/StreamAggPrule.java
@@ -54,6 +54,7 @@ public class StreamAggPrule extends AggPruleBase {
 
   @Override
   public void onMatch(RelOptRuleCall call) {
+
     final DrillAggregateRel aggregate = call.rel(0);
     RelNode input = aggregate.getInput();
     final RelCollation collation = getCollation(aggregate);