You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@beam.apache.org by ie...@apache.org on 2019/03/27 20:10:59 UTC

[beam] branch master updated: [BEAM-6241] Add support for aggregates using withQueryFn to MongoDbIO

This is an automated email from the ASF dual-hosted git repository.

iemejia pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 7e8d3c2  [BEAM-6241] Add support for aggregates using withQueryFn to MongoDbIO
     new c0b0905  Merge pull request #7293: [BEAM-6241] Add support for aggregates using withQueryFn to MongoDbIO
7e8d3c2 is described below

commit 7e8d3c20dcd3b19dfa310b6c072c9a0b64f3da54
Author: Ahmed El.Hussaini <ae...@gmail.com>
AuthorDate: Sun Dec 16 13:25:06 2018 -0500

    [BEAM-6241] Add support for aggregates using withQueryFn to MongoDbIO
    
    - Added support to limit results.
    - Abstracted MongoDB query builder into a separate class.
    - Cleaned the BoundedMongoDbReader start method.
    - Added support to pass ObjectId as string.
    - Separated MongoDB's find and aggregation into two separate query classes.
    - Utilized SerializableFunction in both query builders.
    - [BEAM-4567] Fix splitting with bucket auto (fixes use of Atlas MongoDB).
---
 .../beam/sdk/io/mongodb/AggregationQuery.java      |  76 +++++
 .../org/apache/beam/sdk/io/mongodb/FindQuery.java  | 109 ++++++++
 .../org/apache/beam/sdk/io/mongodb/MongoDbIO.java  | 306 ++++++++++++++-------
 .../org/apache/beam/sdk/io/mongodb/SSLUtils.java   |   6 +-
 .../apache/beam/sdk/io/mongodb/MongoDbIOTest.java  | 114 ++++++--
 5 files changed, 485 insertions(+), 126 deletions(-)

diff --git a/sdks/java/io/mongodb/src/main/java/org/apache/beam/sdk/io/mongodb/AggregationQuery.java b/sdks/java/io/mongodb/src/main/java/org/apache/beam/sdk/io/mongodb/AggregationQuery.java
new file mode 100644
index 0000000..28de393
--- /dev/null
+++ b/sdks/java/io/mongodb/src/main/java/org/apache/beam/sdk/io/mongodb/AggregationQuery.java
@@ -0,0 +1,76 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.io.mongodb;
+
+import com.google.auto.value.AutoValue;
+import com.mongodb.client.MongoCollection;
+import com.mongodb.client.MongoCursor;
+import com.mongodb.lang.Nullable;
+import java.util.ArrayList;
+import java.util.List;
+import org.apache.beam.sdk.annotations.Experimental;
+import org.apache.beam.sdk.transforms.SerializableFunction;
+import org.bson.BsonDocument;
+import org.bson.Document;
+
+/** Query function that runs a MongoDB aggregation pipeline and returns its result cursor. */
+@Experimental(Experimental.Kind.SOURCE_SINK)
+@AutoValue
+public abstract class AggregationQuery
+    implements SerializableFunction<MongoCollection<Document>, MongoCursor<Document>> {
+
+  abstract List<BsonDocument> mongoDbPipeline();
+
+  @Nullable
+  abstract BsonDocument bucket();
+
+  private static Builder builder() {
+    return new AutoValue_AggregationQuery.Builder().setMongoDbPipeline(new ArrayList<>());
+  }
+
+  abstract Builder toBuilder();
+
+  public static AggregationQuery create() {
+    return builder().build();
+  }
+
+  @AutoValue.Builder
+  abstract static class Builder {
+    abstract Builder setMongoDbPipeline(List<BsonDocument> mongoDbPipeline);
+
+    abstract Builder setBucket(BsonDocument bucket);
+
+    abstract AggregationQuery build();
+  }
+
+  public AggregationQuery withMongoDbPipeline(List<BsonDocument> mongoDbPipeline) {
+    return toBuilder().setMongoDbPipeline(mongoDbPipeline).build();
+  }
+
+  /** Runs the pipeline, with {@link #bucket()} appended as the final stage when set. */
+  @Override
+  public MongoCursor<Document> apply(MongoCollection<Document> collection) {
+    // Copy: splits built via toBuilder() share the stored list, so it must never be mutated.
+    List<BsonDocument> pipeline = new ArrayList<>(mongoDbPipeline());
+    if (bucket() != null) {
+      // Append after all user stages; never overwrite the user's last stage.
+      pipeline.add(bucket());
+    }
+    return collection.aggregate(pipeline).iterator();
+  }
+}
diff --git a/sdks/java/io/mongodb/src/main/java/org/apache/beam/sdk/io/mongodb/FindQuery.java b/sdks/java/io/mongodb/src/main/java/org/apache/beam/sdk/io/mongodb/FindQuery.java
new file mode 100644
index 0000000..9da9d4d
--- /dev/null
+++ b/sdks/java/io/mongodb/src/main/java/org/apache/beam/sdk/io/mongodb/FindQuery.java
@@ -0,0 +1,109 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.io.mongodb;
+
+import static org.apache.beam.vendor.guava.v20_0.com.google.common.base.Preconditions.checkArgument;
+
+import com.google.auto.value.AutoValue;
+import com.mongodb.BasicDBObject;
+import com.mongodb.MongoClient;
+import com.mongodb.client.MongoCollection;
+import com.mongodb.client.MongoCursor;
+import com.mongodb.client.model.Projections;
+import java.util.Collections;
+import java.util.List;
+import javax.annotation.Nullable;
+import org.apache.beam.sdk.annotations.Experimental;
+import org.apache.beam.sdk.transforms.SerializableFunction;
+import org.bson.BsonDocument;
+import org.bson.Document;
+import org.bson.conversions.Bson;
+
+/** Query function that opens a {@code find} cursor with filters, limit and projection. */
+@Experimental(Experimental.Kind.SOURCE_SINK)
+@AutoValue
+public abstract class FindQuery
+    implements SerializableFunction<MongoCollection<Document>, MongoCursor<Document>> {
+
+  @Nullable
+  abstract BsonDocument filters();
+
+  abstract int limit();
+
+  abstract List<String> projection();
+
+  private static Builder builder() {
+    return new AutoValue_FindQuery.Builder()
+        .setLimit(0)
+        .setProjection(Collections.emptyList())
+        .setFilters(new BsonDocument());
+  }
+
+  abstract Builder toBuilder();
+
+  public static FindQuery create() {
+    return builder().build();
+  }
+
+  @AutoValue.Builder
+  abstract static class Builder {
+    abstract Builder setFilters(@Nullable BsonDocument filters);
+
+    abstract Builder setLimit(int limit);
+
+    abstract Builder setProjection(List<String> projection);
+
+    abstract FindQuery build();
+  }
+
+  /** Sets the filters from an already-encoded {@link BsonDocument}. */
+  private FindQuery withFilters(BsonDocument filters) {
+    return toBuilder().setFilters(filters).build();
+  }
+
+  /** Convert the Bson filters into a BsonDocument via default encoding. */
+  static BsonDocument bson2BsonDocument(Bson filters) {
+    return filters.toBsonDocument(BasicDBObject.class, MongoClient.getDefaultCodecRegistry());
+  }
+
+  /** Sets the filters, encoding {@code filters} with the driver's default codec registry. */
+  public FindQuery withFilters(Bson filters) {
+    return withFilters(bson2BsonDocument(filters));
+  }
+
+  /** Sets the maximum number of documents to return; 0 (the default) means no limit. */
+  public FindQuery withLimit(int limit) {
+    return toBuilder().setLimit(limit).build();
+  }
+
+  /** Sets the field names to include in returned documents; must not be null. */
+  public FindQuery withProjection(List<String> projection) {
+    checkArgument(projection != null, "projection can not be null");
+    return toBuilder().setProjection(projection).build();
+  }
+
+  @Override
+  public MongoCursor<Document> apply(MongoCollection<Document> collection) {
+    return collection
+        .find()
+        .filter(filters())
+        .limit(limit())
+        .projection(Projections.include(projection()))
+        .iterator();
+  }
+}
diff --git a/sdks/java/io/mongodb/src/main/java/org/apache/beam/sdk/io/mongodb/MongoDbIO.java b/sdks/java/io/mongodb/src/main/java/org/apache/beam/sdk/io/mongodb/MongoDbIO.java
index e9694f8..f16f838 100644
--- a/sdks/java/io/mongodb/src/main/java/org/apache/beam/sdk/io/mongodb/MongoDbIO.java
+++ b/sdks/java/io/mongodb/src/main/java/org/apache/beam/sdk/io/mongodb/MongoDbIO.java
@@ -17,7 +17,7 @@
  */
 package org.apache.beam.sdk.io.mongodb;
 
-import static com.mongodb.client.model.Projections.include;
+import static org.apache.beam.sdk.io.mongodb.FindQuery.bson2BsonDocument;
 import static org.apache.beam.vendor.guava.v20_0.com.google.common.base.Preconditions.checkArgument;
 
 import com.google.auto.value.AutoValue;
@@ -26,13 +26,18 @@ import com.mongodb.MongoBulkWriteException;
 import com.mongodb.MongoClient;
 import com.mongodb.MongoClientOptions;
 import com.mongodb.MongoClientURI;
+import com.mongodb.client.AggregateIterable;
 import com.mongodb.client.MongoCollection;
 import com.mongodb.client.MongoCursor;
 import com.mongodb.client.MongoDatabase;
+import com.mongodb.client.model.Aggregates;
+import com.mongodb.client.model.Filters;
 import com.mongodb.client.model.InsertManyOptions;
 import java.util.ArrayList;
 import java.util.Arrays;
+import java.util.Collections;
 import java.util.List;
+import java.util.stream.Collectors;
 import javax.annotation.Nullable;
 import org.apache.beam.sdk.annotations.Experimental;
 import org.apache.beam.sdk.coders.Coder;
@@ -42,12 +47,18 @@ import org.apache.beam.sdk.options.PipelineOptions;
 import org.apache.beam.sdk.transforms.DoFn;
 import org.apache.beam.sdk.transforms.PTransform;
 import org.apache.beam.sdk.transforms.ParDo;
+import org.apache.beam.sdk.transforms.SerializableFunction;
 import org.apache.beam.sdk.transforms.display.DisplayData;
 import org.apache.beam.sdk.values.PBegin;
 import org.apache.beam.sdk.values.PCollection;
 import org.apache.beam.sdk.values.PDone;
 import org.apache.beam.vendor.guava.v20_0.com.google.common.annotations.VisibleForTesting;
+import org.bson.BsonDocument;
+import org.bson.BsonInt32;
+import org.bson.BsonString;
 import org.bson.Document;
+import org.bson.conversions.Bson;
+import org.bson.types.ObjectId;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -106,9 +117,11 @@ public class MongoDbIO {
         .setKeepAlive(true)
         .setMaxConnectionIdleTime(60000)
         .setNumSplits(0)
+        .setBucketAuto(false)
         .setSslEnabled(false)
         .setIgnoreSSLCertificate(false)
         .setSslInvalidHostNameAllowed(false)
+        .setQueryFn(FindQuery.create())
         .build();
   }
 
@@ -153,19 +166,18 @@ public class MongoDbIO {
     @Nullable
     abstract String collection();
 
-    @Nullable
-    abstract String filter();
+    abstract int numSplits();
 
-    @Nullable
-    abstract List<String> projection();
+    abstract boolean bucketAuto();
 
-    abstract int numSplits();
+    abstract SerializableFunction<MongoCollection<Document>, MongoCursor<Document>> queryFn();
 
     abstract Builder builder();
 
     @AutoValue.Builder
     abstract static class Builder {
       abstract Builder setUri(String uri);
+
       /**
        * @deprecated This is deprecated in the MongoDB API and will be removed in a future version.
        */
@@ -184,11 +196,12 @@ public class MongoDbIO {
 
       abstract Builder setCollection(String collection);
 
-      abstract Builder setFilter(String filter);
+      abstract Builder setNumSplits(int numSplits);
 
-      abstract Builder setProjection(List<String> fieldNames);
+      abstract Builder setBucketAuto(boolean bucketAuto);
 
-      abstract Builder setNumSplits(int numSplits);
+      abstract Builder setQueryFn(
+          SerializableFunction<MongoCollection<Document>, MongoCursor<Document>> queryBuilder);
 
       abstract Read build();
     }
@@ -276,16 +289,40 @@ public class MongoDbIO {
       return builder().setCollection(collection).build();
     }
 
-    /** Sets a filter on the documents in a collection. */
+    /**
+     * Sets a filter on the documents in a collection.
+     *
+     * @deprecated Filtering manually is discouraged. Use {@link #withQueryFn(SerializableFunction)}
+     *     with {@link FindQuery#withFilters(Bson)} as an argument to set up the filter.
+     */
+    @Deprecated
     public Read withFilter(String filter) {
       checkArgument(filter != null, "filter can not be null");
-      return builder().setFilter(filter).build();
+      checkArgument(
+          this.queryFn() instanceof FindQuery,
+          "withFilter is only supported for FindQuery API");
+      FindQuery findQuery = (FindQuery) queryFn();
+      FindQuery queryWithFilter =
+          findQuery.toBuilder().setFilters(bson2BsonDocument(Document.parse(filter))).build();
+      return builder().setQueryFn(queryWithFilter).build();
     }
 
-    /** Sets a projection on the documents in a collection. */
+    /**
+     * Sets a projection on the documents in a collection.
+     *
+     * @deprecated Use {@link #withQueryFn(SerializableFunction)} with {@link
+     *     FindQuery#withProjection(List)} as an argument to set up the projection.
+     */
+    @Deprecated
     public Read withProjection(final String... fieldNames) {
       checkArgument(fieldNames.length > 0, "projection can not be null");
-      return builder().setProjection(Arrays.asList(fieldNames)).build();
+      checkArgument(
+          this.queryFn() instanceof FindQuery,
+          "withProjection is only supported for FindQuery API");
+      FindQuery findQuery = (FindQuery) queryFn();
+      FindQuery queryWithProjection =
+          findQuery.toBuilder().setProjection(Arrays.asList(fieldNames)).build();
+      return builder().setQueryFn(queryWithProjection).build();
     }
 
     /** Sets the user defined number of splits. */
@@ -294,6 +331,17 @@ public class MongoDbIO {
       return builder().setNumSplits(numSplits).build();
     }
 
+    /** Sets whether {@link FindQuery} splits use {@code $bucketAuto} instead of splitVector. */
+    public Read withBucketAuto(boolean bucketAuto) {
+      return builder().setBucketAuto(bucketAuto).build();
+    }
+
+    /** Sets the query; split() supports only {@link FindQuery} and {@link AggregationQuery}. */
+    public Read withQueryFn(
+        SerializableFunction<MongoCollection<Document>, MongoCursor<Document>> queryBuilderFn) {
+      return builder().setQueryFn(queryBuilderFn).build();
+    }
+
     @Override
     public PCollection<Document> expand(PBegin input) {
       checkArgument(uri() != null, "withUri() is required");
@@ -311,15 +359,11 @@ public class MongoDbIO {
       builder.add(DisplayData.item("sslEnabled", sslEnabled()));
       builder.add(DisplayData.item("sslInvalidHostNameAllowed", sslInvalidHostNameAllowed()));
       builder.add(DisplayData.item("ignoreSSLCertificate", ignoreSSLCertificate()));
-
       builder.add(DisplayData.item("database", database()));
       builder.add(DisplayData.item("collection", collection()));
-      builder.addIfNotNull(DisplayData.item("filter", filter()));
-      if (projection() != null) {
-        builder.addIfNotNull(
-            DisplayData.item("projection", Arrays.toString(projection().toArray())));
-      }
       builder.add(DisplayData.item("numSplit", numSplits()));
+      builder.add(DisplayData.item("bucketAuto", bucketAuto()));
+      builder.add(DisplayData.item("queryFn", queryFn().toString()));
     }
   }
 
@@ -406,43 +450,70 @@ public class MongoDbIO {
         MongoDatabase mongoDatabase = mongoClient.getDatabase(spec.database());
 
         List<Document> splitKeys;
-        if (spec.numSplits() > 0) {
-          // the user defines his desired number of splits
-          // calculate the batch size
-          long estimatedSizeBytes =
-              getEstimatedSizeBytes(mongoClient, spec.database(), spec.collection());
-          desiredBundleSizeBytes = estimatedSizeBytes / spec.numSplits();
-        }
+        List<BoundedSource<Document>> sources = new ArrayList<>();
 
-        // the desired batch size is small, using default chunk size of 1MB
-        if (desiredBundleSizeBytes < 1024L * 1024L) {
-          desiredBundleSizeBytes = 1024L * 1024L;
-        }
+        if (spec.queryFn().getClass() == AutoValue_FindQuery.class) {
+          if (spec.bucketAuto()) {
+            splitKeys = buildAutoBuckets(mongoDatabase, spec);
+          } else {
+            if (spec.numSplits() > 0) {
+              // the user defines his desired number of splits
+              // calculate the batch size
+              long estimatedSizeBytes =
+                  getEstimatedSizeBytes(mongoClient, spec.database(), spec.collection());
+              desiredBundleSizeBytes = estimatedSizeBytes / spec.numSplits();
+            }
+
+            // the desired batch size is small, using default chunk size of 1MB
+            if (desiredBundleSizeBytes < 1024L * 1024L) {
+              desiredBundleSizeBytes = 1024L * 1024L;
+            }
+
+            // now we have the batch size (provided by user or provided by the runner)
+            // we use Mongo splitVector command to get the split keys
+            BasicDBObject splitVectorCommand = new BasicDBObject();
+            splitVectorCommand.append("splitVector", spec.database() + "." + spec.collection());
+            splitVectorCommand.append("keyPattern", new BasicDBObject().append("_id", 1));
+            splitVectorCommand.append("force", false);
+            // maxChunkSize is the Mongo partition size in MB
+            LOG.debug("Splitting in chunk of {} MB", desiredBundleSizeBytes / 1024 / 1024);
+            splitVectorCommand.append("maxChunkSize", desiredBundleSizeBytes / 1024 / 1024);
+            Document splitVectorCommandResult = mongoDatabase.runCommand(splitVectorCommand);
+            splitKeys = (List<Document>) splitVectorCommandResult.get("splitKeys");
+          }
 
-        // now we have the batch size (provided by user or provided by the runner)
-        // we use Mongo splitVector command to get the split keys
-        BasicDBObject splitVectorCommand = new BasicDBObject();
-        splitVectorCommand.append("splitVector", spec.database() + "." + spec.collection());
-        splitVectorCommand.append("keyPattern", new BasicDBObject().append("_id", 1));
-        splitVectorCommand.append("force", false);
-        // maxChunkSize is the Mongo partition size in MB
-        LOG.debug("Splitting in chunk of {} MB", desiredBundleSizeBytes / 1024 / 1024);
-        splitVectorCommand.append("maxChunkSize", desiredBundleSizeBytes / 1024 / 1024);
-        Document splitVectorCommandResult = mongoDatabase.runCommand(splitVectorCommand);
-        splitKeys = (List<Document>) splitVectorCommandResult.get("splitKeys");
+          if (splitKeys.size() < 1) {
+            LOG.debug("Split keys is low, using an unique source");
+            return Collections.singletonList(this);
+          }
 
-        List<BoundedSource<Document>> sources = new ArrayList<>();
-        if (splitKeys.size() < 1) {
-          LOG.debug("Split keys is low, using an unique source");
-          sources.add(this);
-          return sources;
-        }
+          List<String> keys = splitKeysToFilters(splitKeys);
+          for (String shardFilter : splitKeysToFilters(splitKeys)) {
+            SerializableFunction<MongoCollection<Document>, MongoCursor<Document>> queryFn =
+                spec.queryFn();
 
-        LOG.debug("Number of splits is {}", splitKeys.size());
-        for (String shardFilter : splitKeysToFilters(splitKeys, spec.filter())) {
-          sources.add(new BoundedMongoDbSource(spec.withFilter(shardFilter)));
-        }
+            BsonDocument filters = bson2BsonDocument(Document.parse(shardFilter));
+            FindQuery findQuery = (FindQuery) queryFn;
+            FindQuery queryWithFilter = findQuery.toBuilder().setFilters(filters).build();
+            sources.add(new BoundedMongoDbSource(spec.withQueryFn(queryWithFilter)));
+          }
+        } else {
+          SerializableFunction<MongoCollection<Document>, MongoCursor<Document>> queryFn =
+              spec.queryFn();
+          AggregationQuery aggregationQuery = (AggregationQuery) queryFn;
+          if (aggregationQuery.mongoDbPipeline().stream()
+              .anyMatch(s -> s.keySet().contains("$limit"))) {
+            return Collections.singletonList(this);
+          }
 
+          splitKeys = buildAutoBuckets(mongoDatabase, spec);
+
+          for (BsonDocument shardFilter : splitKeysToMatch(splitKeys)) {
+            AggregationQuery queryWithBucket =
+                aggregationQuery.toBuilder().setBucket(shardFilter).build();
+            sources.add(new BoundedMongoDbSource(spec.withQueryFn(queryWithBucket)));
+          }
+        }
         return sources;
       }
     }
@@ -470,11 +541,10 @@ public class MongoDbIO {
      * </ul>
      *
      * @param splitKeys The list of split keys.
-     * @param additionalFilter A custom (user) additional filter to append to the range filters.
      * @return A list of filters containing the ranges.
      */
     @VisibleForTesting
-    static List<String> splitKeysToFilters(List<Document> splitKeys, String additionalFilter) {
+    static List<String> splitKeysToFilters(List<Document> splitKeys) {
       ArrayList<String> filters = new ArrayList<>();
       String lowestBound = null; // lower boundary (previous split in the iteration)
       for (int i = 0; i < splitKeys.size(); i++) {
@@ -484,7 +554,7 @@ public class MongoDbIO {
           // this is the first split in the list, the filter defines
           // the range from the beginning up to this split
           rangeFilter = String.format("{ $and: [ {\"_id\":{$lte:ObjectId(\"%s\")}}", splitKey);
-          filters.add(formatFilter(rangeFilter, additionalFilter));
+          filters.add(String.format("%s ]}", rangeFilter));
         } else if (i == splitKeys.size() - 1) {
           // this is the last split in the list, the filters define
           // the range from the previous split to the current split and also
@@ -493,38 +563,93 @@ public class MongoDbIO {
               String.format(
                   "{ $and: [ {\"_id\":{$gt:ObjectId(\"%s\")," + "$lte:ObjectId(\"%s\")}}",
                   lowestBound, splitKey);
-          filters.add(formatFilter(rangeFilter, additionalFilter));
+          filters.add(String.format("%s ]}", rangeFilter));
           rangeFilter = String.format("{ $and: [ {\"_id\":{$gt:ObjectId(\"%s\")}}", splitKey);
-          filters.add(formatFilter(rangeFilter, additionalFilter));
+          filters.add(String.format("%s ]}", rangeFilter));
         } else {
           // we are between two splits
           rangeFilter =
               String.format(
                   "{ $and: [ {\"_id\":{$gt:ObjectId(\"%s\")," + "$lte:ObjectId(\"%s\")}}",
                   lowestBound, splitKey);
-          filters.add(formatFilter(rangeFilter, additionalFilter));
+          filters.add(String.format("%s ]}", rangeFilter));
         }
 
         lowestBound = splitKey;
       }
+
       return filters;
     }
 
     /**
-     * Cleanly format range filter, optionally adding the users filter if specified.
+     * Transform a list of split keys as a list of filters containing corresponding range.
+     *
+     * <p>The list of split keys contains BSon Document basically containing for example:
+     *
+     * <ul>
+     *   <li>_id: 56
+     *   <li>_id: 109
+     *   <li>_id: 256
+     * </ul>
+     *
+     * <p>This method will generate a list of range filters performing the following splits:
+     *
+     * <ul>
+     *   <li>from the beginning of the collection up to _id 56, so basically data with _id lower
+     *       than 56
+     *   <li>from _id 57 up to _id 109
+     *   <li>from _id 110 up to _id 256
+     *   <li>from _id 257 up to the end of the collection, so basically data with _id greater than
+     *       257
+     * </ul>
      *
-     * @param filter The range filter.
-     * @param additionalFilter The users filter. Null if unspecified.
-     * @return The cleanly formatted range filter.
+     * @param splitKeys The list of split keys.
+     * @return A list of filters containing the ranges.
      */
-    private static String formatFilter(String filter, @Nullable String additionalFilter) {
-      if (additionalFilter != null && !additionalFilter.isEmpty()) {
-        // user provided a filter, we append the user filter to the range filter
-        return String.format("%s,%s ]}", filter, additionalFilter);
-      } else {
-        // user didn't provide a filter, just cleanly close the range filter
-        return String.format("%s ]}", filter);
+    @VisibleForTesting
+    static List<BsonDocument> splitKeysToMatch(List<Document> splitKeys) {
+      List<Bson> aggregates = new ArrayList<>();
+      ObjectId lowestBound = null; // lower boundary (previous split in the iteration)
+      for (int i = 0; i < splitKeys.size(); i++) {
+        ObjectId splitKey = splitKeys.get(i).getObjectId("_id");
+        String rangeFilter;
+        if (i == 0) {
+          aggregates.add(Aggregates.match(Filters.lte("_id", splitKey)));
+        } else if (i == splitKeys.size() - 1) {
+          aggregates.add(Aggregates.match(Filters.and(Filters.gt("_id", lowestBound))));
+        } else {
+          aggregates.add(
+              Aggregates.match(
+                  Filters.and(Filters.gt("_id", lowestBound), Filters.lte("_id", splitKey))));
+        }
+
+        lowestBound = splitKey;
+      }
+      return aggregates.stream()
+          .map(s -> s.toBsonDocument(BasicDBObject.class, MongoClient.getDefaultCodecRegistry()))
+          .collect(Collectors.toList());
+    }
+
+    @VisibleForTesting
+    static List<Document> buildAutoBuckets(MongoDatabase mongoDatabase, Read spec) {
+      List<Document> splitKeys = new ArrayList<>();
+      MongoCollection<Document> mongoCollection = mongoDatabase.getCollection(spec.collection());
+      BsonDocument bucketAutoConfig = new BsonDocument();
+      bucketAutoConfig.put("groupBy", new BsonString("$_id"));
+      // 10 is the default number of buckets
+      bucketAutoConfig.put("buckets", new BsonInt32(spec.numSplits() > 0 ? spec.numSplits() : 10));
+      BsonDocument bucketAuto = new BsonDocument("$bucketAuto", bucketAutoConfig);
+      List<BsonDocument> aggregates = new ArrayList<>();
+      aggregates.add(bucketAuto);
+      AggregateIterable<Document> buckets = mongoCollection.aggregate(aggregates);
+
+      for (Document bucket : buckets) {
+        Document filter = new Document();
+        filter.put("_id", ((Document) bucket.get("_id")).get("min"));
+        splitKeys.add(filter);
       }
+
+      return splitKeys;
     }
   }
 
@@ -542,35 +667,12 @@ public class MongoDbIO {
     @Override
     public boolean start() {
       Read spec = source.spec;
-      client =
-          new MongoClient(
-              new MongoClientURI(
-                  spec.uri(),
-                  getOptions(
-                      spec.keepAlive(),
-                      spec.maxConnectionIdleTime(),
-                      spec.sslEnabled(),
-                      spec.sslInvalidHostNameAllowed())));
 
+      // MongoDB Connection preparation
+      client = createClient(spec);
       MongoDatabase mongoDatabase = client.getDatabase(spec.database());
-
       MongoCollection<Document> mongoCollection = mongoDatabase.getCollection(spec.collection());
-
-      if (spec.filter() == null) {
-        if (spec.projection() == null) {
-          cursor = mongoCollection.find().iterator();
-        } else {
-          cursor = mongoCollection.find().projection(include(spec.projection())).iterator();
-        }
-      } else {
-        Document bson = Document.parse(spec.filter());
-        if (spec.projection() == null) {
-          cursor = mongoCollection.find(bson).iterator();
-        } else {
-          cursor = mongoCollection.find(bson).projection(include(spec.projection())).iterator();
-        }
-      }
-
+      cursor = spec.queryFn().apply(mongoCollection);
       return advance();
     }
 
@@ -579,9 +681,8 @@ public class MongoDbIO {
       if (cursor.hasNext()) {
         current = cursor.next();
         return true;
-      } else {
-        return false;
       }
+      return false;
     }
 
     @Override
@@ -609,6 +710,17 @@ public class MongoDbIO {
         LOG.warn("Error closing MongoDB client", e);
       }
     }
+
+    private MongoClient createClient(Read spec) {
+      return new MongoClient(
+          new MongoClientURI(
+              spec.uri(),
+              getOptions(
+                  spec.keepAlive(),
+                  spec.maxConnectionIdleTime(),
+                  spec.sslEnabled(),
+                  spec.sslInvalidHostNameAllowed())));
+    }
   }
 
   /** A {@link PTransform} to write to a MongoDB database. */
@@ -617,6 +729,7 @@ public class MongoDbIO {
 
     @Nullable
     abstract String uri();
+
     /**
      * @deprecated This is deprecated in the MongoDB API and will be removed in a future version.
      */
@@ -646,6 +759,7 @@ public class MongoDbIO {
     @AutoValue.Builder
     abstract static class Builder {
       abstract Builder setUri(String uri);
+
       /**
        * @deprecated This is deprecated in the MongoDB API and will be removed in a future version.
        */
diff --git a/sdks/java/io/mongodb/src/main/java/org/apache/beam/sdk/io/mongodb/SSLUtils.java b/sdks/java/io/mongodb/src/main/java/org/apache/beam/sdk/io/mongodb/SSLUtils.java
index 9a778a3..2a5314e 100644
--- a/sdks/java/io/mongodb/src/main/java/org/apache/beam/sdk/io/mongodb/SSLUtils.java
+++ b/sdks/java/io/mongodb/src/main/java/org/apache/beam/sdk/io/mongodb/SSLUtils.java
@@ -26,10 +26,10 @@ import javax.net.ssl.TrustManager;
 import javax.net.ssl.X509TrustManager;
 
 /** Utility class for registration of ssl context, and to allow all certificate requests. */
-public class SSLUtils {
+class SSLUtils {
 
   /** static class to allow all requests. */
-  static TrustManager[] trustAllCerts =
+  private static final TrustManager[] trustAllCerts =
       new TrustManager[] {
         new X509TrustManager() {
           @Override
@@ -50,7 +50,7 @@ public class SSLUtils {
    *
    * @return SSLContext
    */
-  public static SSLContext ignoreSSLCertificate() {
+  static SSLContext ignoreSSLCertificate() {
     try {
       // Install the all-trusting trust manager
       SSLContext sc = SSLContext.getInstance("TLS");
diff --git a/sdks/java/io/mongodb/src/test/java/org/apache/beam/sdk/io/mongodb/MongoDbIOTest.java b/sdks/java/io/mongodb/src/test/java/org/apache/beam/sdk/io/mongodb/MongoDbIOTest.java
index 59efc77..b68edb2 100644
--- a/sdks/java/io/mongodb/src/test/java/org/apache/beam/sdk/io/mongodb/MongoDbIOTest.java
+++ b/sdks/java/io/mongodb/src/test/java/org/apache/beam/sdk/io/mongodb/MongoDbIOTest.java
@@ -23,6 +23,7 @@ import static org.junit.Assert.assertFalse;
 import com.mongodb.MongoClient;
 import com.mongodb.client.MongoCollection;
 import com.mongodb.client.MongoDatabase;
+import com.mongodb.client.model.Filters;
 import de.flapdoodle.embed.mongo.MongodExecutable;
 import de.flapdoodle.embed.mongo.MongodProcess;
 import de.flapdoodle.embed.mongo.MongodStarter;
@@ -40,13 +41,16 @@ import org.apache.beam.sdk.testing.PAssert;
 import org.apache.beam.sdk.testing.TestPipeline;
 import org.apache.beam.sdk.transforms.Count;
 import org.apache.beam.sdk.transforms.Create;
-import org.apache.beam.sdk.transforms.Filter;
 import org.apache.beam.sdk.transforms.MapElements;
 import org.apache.beam.sdk.transforms.SimpleFunction;
 import org.apache.beam.sdk.values.KV;
 import org.apache.beam.sdk.values.PCollection;
 import org.apache.beam.vendor.guava.v20_0.com.google.common.collect.Iterators;
+import org.bson.BsonDocument;
+import org.bson.BsonInt32;
+import org.bson.BsonString;
 import org.bson.Document;
+import org.bson.types.ObjectId;
 import org.junit.AfterClass;
 import org.junit.BeforeClass;
 import org.junit.ClassRule;
@@ -119,7 +123,7 @@ public class MongoDbIOTest {
     documents.add(new Document("_id", 56));
     documents.add(new Document("_id", 109));
     documents.add(new Document("_id", 256));
-    List<String> filters = MongoDbIO.BoundedMongoDbSource.splitKeysToFilters(documents, null);
+    List<String> filters = MongoDbIO.BoundedMongoDbSource.splitKeysToFilters(documents);
     assertEquals(4, filters.size());
     assertEquals("{ $and: [ {\"_id\":{$lte:ObjectId(\"56\")}} ]}", filters.get(0));
     assertEquals(
@@ -130,6 +134,48 @@ public class MongoDbIOTest {
   }
 
   @Test
+  public void testSplitIntoBucket() {
+    ArrayList<Document> documents = new ArrayList<>();
+    documents.add(new Document("_id", new ObjectId("52cc8f6254c5317943000005")));
+    documents.add(new Document("_id", new ObjectId("52cc8f6254c5317943000007")));
+    documents.add(new Document("_id", new ObjectId("54242e9e54c531ef8800001f")));
+    documents.add(new Document("_id", new ObjectId("54242e9e54c531ef88000020")));
+    List<BsonDocument> buckets = MongoDbIO.BoundedMongoDbSource.splitKeysToMatch(documents);
+    assertEquals(4, buckets.size());
+    assertEquals(
+        "{ \"$match\" : { \"_id\" : { \"$lte\" : { \"$oid\" : \"52cc8f6254c5317943000005\" } } } }",
+        buckets.get(0).toString());
+    assertEquals(
+        "{ \"$match\" : { \"_id\" : { \"$gt\" : { \"$oid\" : \"52cc8f6254c5317943000005\" }, \"$lte\" : { \"$oid\" : \"52cc8f6254c5317943000007\" } } } }",
+        buckets.get(1).toString());
+    assertEquals(
+        "{ \"$match\" : { \"_id\" : { \"$gt\" : { \"$oid\" : \"52cc8f6254c5317943000007\" }, \"$lte\" : { \"$oid\" : \"54242e9e54c531ef8800001f\" } } } }",
+        buckets.get(2).toString());
+    assertEquals(
+        "{ \"$match\" : { \"_id\" : { \"$gt\" : { \"$oid\" : \"54242e9e54c531ef8800001f\" } } } }",
+        buckets.get(3).toString());
+  }
+
+  @Test
+  public void testBuildAutoBuckets() {
+    List<BsonDocument> aggregates = new ArrayList<BsonDocument>();
+    aggregates.add(
+        new BsonDocument(
+            "$match",
+            new BsonDocument("country", new BsonDocument("$eq", new BsonString("England")))));
+
+    MongoDbIO.Read spec =
+        MongoDbIO.read()
+            .withUri("mongodb://localhost:" + port)
+            .withDatabase(DATABASE)
+            .withCollection(COLLECTION)
+            .withQueryFn(AggregationQuery.create().withMongoDbPipeline(aggregates));
+    MongoDatabase database = client.getDatabase(DATABASE);
+    List<Document> buckets = MongoDbIO.BoundedMongoDbSource.buildAutoBuckets(database, spec);
+    assertEquals(10, buckets.size());
+  }
+
+  @Test
   public void testFullRead() {
     PCollection<Document> output =
         pipeline.apply(
@@ -194,7 +240,7 @@ public class MongoDbIOTest {
                 .withUri("mongodb://localhost:" + port)
                 .withDatabase(DATABASE)
                 .withCollection(COLLECTION)
-                .withFilter("{\"scientist\":\"Einstein\"}"));
+                .withQueryFn(FindQuery.create().withFilters(Filters.eq("scientist", "Einstein"))));
 
     PAssert.thatSingleton(output.apply("Count", Count.globally())).isEqualTo(100L);
 
@@ -202,48 +248,63 @@ public class MongoDbIOTest {
   }
 
   @Test
-  public void testReadWithFilterAndProjection() {
+  public void testReadWithFilterAndLimit() throws Exception {
     PCollection<Document> output =
         pipeline.apply(
             MongoDbIO.read()
                 .withUri("mongodb://localhost:" + port)
                 .withDatabase(DATABASE)
                 .withCollection(COLLECTION)
-                .withFilter("{\"scientist\":\"Einstein\"}")
-                .withProjection("country", "scientist"));
+                .withNumSplits(10)
+                .withQueryFn(
+                    FindQuery.create()
+                        .withFilters(Filters.eq("scientist", "Einstein"))
+                        .withLimit(5)));
 
-    PAssert.thatSingleton(
-            output
-                .apply(
-                    "Map Scientist",
-                    Filter.by(
-                        (Document doc) ->
-                            doc.get("country") != null && doc.get("scientist") != null))
-                .apply("Count", Count.globally()))
-        .isEqualTo(100L);
+    PAssert.thatSingleton(output.apply("Count", Count.globally())).isEqualTo(5L);
 
     pipeline.run();
   }
 
   @Test
-  public void testReadWithProjection() {
+  public void testReadWithAggregate() throws Exception {
+    List<BsonDocument> aggregates = new ArrayList<BsonDocument>();
+    aggregates.add(
+        new BsonDocument(
+            "$match",
+            new BsonDocument("country", new BsonDocument("$eq", new BsonString("England")))));
+
     PCollection<Document> output =
         pipeline.apply(
             MongoDbIO.read()
                 .withUri("mongodb://localhost:" + port)
                 .withDatabase(DATABASE)
                 .withCollection(COLLECTION)
-                .withProjection("country"));
+                .withQueryFn(AggregationQuery.create().withMongoDbPipeline(aggregates)));
 
-    PAssert.thatSingleton(
-            output
-                .apply(
-                    "Map scientist",
-                    Filter.by(
-                        (Document doc) ->
-                            doc.get("country") != null && doc.get("scientist") == null))
-                .apply("Count", Count.globally()))
-        .isEqualTo(1000L);
+    PAssert.thatSingleton(output.apply("Count", Count.globally())).isEqualTo(300L);
+
+    pipeline.run();
+  }
+
+  @Test
+  public void testReadWithAggregateWithLimit() throws Exception {
+    List<BsonDocument> aggregates = new ArrayList<BsonDocument>();
+    aggregates.add(
+        new BsonDocument(
+            "$match",
+            new BsonDocument("country", new BsonDocument("$eq", new BsonString("England")))));
+    aggregates.add(new BsonDocument("$limit", new BsonInt32(10)));
+
+    PCollection<Document> output =
+        pipeline.apply(
+            MongoDbIO.read()
+                .withUri("mongodb://localhost:" + port)
+                .withDatabase(DATABASE)
+                .withCollection(COLLECTION)
+                .withQueryFn(AggregationQuery.create().withMongoDbPipeline(aggregates)));
+
+    PAssert.thatSingleton(output.apply("Count", Count.globally())).isEqualTo(10L);
 
     pipeline.run();
   }
@@ -317,7 +378,6 @@ public class MongoDbIOTest {
     for (int i = 1; i <= n; i++) {
       int index = i % scientists.length;
       Document document = new Document();
-      document.append("_id", i);
       document.append("scientist", scientists[index]);
       document.append("country", country[index]);
       documents.add(document);