You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by po...@apache.org on 2021/08/12 16:23:58 UTC

[airflow] branch main updated: Add Mongo projections to hook and transfer (#17379)

This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 9875757  Add Mongo projections to hook and transfer (#17379)
9875757 is described below

commit 987575787d82abf5b4e68b669fdb3bcab08965e6
Author: JavierLopezT <ja...@gmail.com>
AuthorDate: Thu Aug 12 18:23:41 2021 +0200

    Add Mongo projections to hook and transfer (#17379)
---
 .../providers/amazon/aws/transfers/mongo_to_s3.py  |  7 ++++++
 airflow/providers/mongo/hooks/mongo.py             | 29 +++++++++++-----------
 .../amazon/aws/transfers/test_mongo_to_s3.py       |  4 +--
 tests/providers/mongo/hooks/test_mongo.py          | 20 +++++++++++++--
 4 files changed, 42 insertions(+), 18 deletions(-)

diff --git a/airflow/providers/amazon/aws/transfers/mongo_to_s3.py b/airflow/providers/amazon/aws/transfers/mongo_to_s3.py
index f21c538..d5a39d0 100644
--- a/airflow/providers/amazon/aws/transfers/mongo_to_s3.py
+++ b/airflow/providers/amazon/aws/transfers/mongo_to_s3.py
@@ -41,6 +41,10 @@ class MongoToS3Operator(BaseOperator):
     :type mongo_collection: str
     :param mongo_query: query to execute. A list including a dict of the query
     :type mongo_query: Union[list, dict]
+    :param mongo_projection: optional parameter to filter the fields returned by
+        the query. It can be a list of field names to include or a dictionary
+        for excluding fields (e.g. `projection={"_id": 0}`)
+    :type mongo_projection: Union[list, dict]
     :param s3_bucket: reference to a specific S3 bucket to store the data
     :type s3_bucket: str
     :param s3_key: in which S3 key the file will be stored
@@ -71,6 +75,7 @@ class MongoToS3Operator(BaseOperator):
         s3_bucket: str,
         s3_key: str,
         mongo_db: Optional[str] = None,
+        mongo_projection: Optional[Union[list, dict]] = None,
         replace: bool = False,
         allow_disk_use: bool = False,
         compression: Optional[str] = None,
@@ -89,6 +94,7 @@ class MongoToS3Operator(BaseOperator):
         # Grab query and determine if we need to run an aggregate pipeline
         self.mongo_query = mongo_query
         self.is_pipeline = isinstance(self.mongo_query, list)
+        self.mongo_projection = mongo_projection
 
         self.s3_bucket = s3_bucket
         self.s3_key = s3_key
@@ -113,6 +119,7 @@ class MongoToS3Operator(BaseOperator):
             results = MongoHook(self.mongo_conn_id).find(
                 mongo_collection=self.mongo_collection,
                 query=cast(dict, self.mongo_query),
+                projection=self.mongo_projection,
                 mongo_db=self.mongo_db,
             )
 
diff --git a/airflow/providers/mongo/hooks/mongo.py b/airflow/providers/mongo/hooks/mongo.py
index 7d1bb84..90ecf4c 100644
--- a/airflow/providers/mongo/hooks/mongo.py
+++ b/airflow/providers/mongo/hooks/mongo.py
@@ -18,7 +18,7 @@
 """Hook for Mongo DB"""
 from ssl import CERT_NONE
 from types import TracebackType
-from typing import List, Optional, Type
+from typing import List, Optional, Type, Union
 
 import pymongo
 from pymongo import MongoClient, ReplaceOne
@@ -122,8 +122,8 @@ class MongoHook(BaseHook):
     ) -> pymongo.command_cursor.CommandCursor:
         """
         Runs an aggregation pipeline and returns the results
-        https://api.mongodb.com/python/current/api/pymongo/collection.html#pymongo.collection.Collection.aggregate
-        https://api.mongodb.com/python/current/examples/aggregation.html
+        https://pymongo.readthedocs.io/en/stable/api/pymongo/collection.html#pymongo.collection.Collection.aggregate
+        https://pymongo.readthedocs.io/en/stable/examples/aggregation.html
         """
         collection = self.get_collection(mongo_collection, mongo_db=mongo_db)
 
@@ -135,25 +135,26 @@ class MongoHook(BaseHook):
         query: dict,
         find_one: bool = False,
         mongo_db: Optional[str] = None,
+        projection: Optional[Union[list, dict]] = None,
         **kwargs,
     ) -> pymongo.cursor.Cursor:
         """
         Runs a mongo find query and returns the results
-        https://api.mongodb.com/python/current/api/pymongo/collection.html#pymongo.collection.Collection.find
+        https://pymongo.readthedocs.io/en/stable/api/pymongo/collection.html#pymongo.collection.Collection.find
         """
         collection = self.get_collection(mongo_collection, mongo_db=mongo_db)
 
         if find_one:
-            return collection.find_one(query, **kwargs)
+            return collection.find_one(query, projection, **kwargs)
         else:
-            return collection.find(query, **kwargs)
+            return collection.find(query, projection, **kwargs)
 
     def insert_one(
         self, mongo_collection: str, doc: dict, mongo_db: Optional[str] = None, **kwargs
     ) -> pymongo.results.InsertOneResult:
         """
         Inserts a single document into a mongo collection
-        https://api.mongodb.com/python/current/api/pymongo/collection.html#pymongo.collection.Collection.insert_one
+        https://pymongo.readthedocs.io/en/stable/api/pymongo/collection.html#pymongo.collection.Collection.insert_one
         """
         collection = self.get_collection(mongo_collection, mongo_db=mongo_db)
 
@@ -164,7 +165,7 @@ class MongoHook(BaseHook):
     ) -> pymongo.results.InsertManyResult:
         """
         Inserts many docs into a mongo collection.
-        https://api.mongodb.com/python/current/api/pymongo/collection.html#pymongo.collection.Collection.insert_many
+        https://pymongo.readthedocs.io/en/stable/api/pymongo/collection.html#pymongo.collection.Collection.insert_many
         """
         collection = self.get_collection(mongo_collection, mongo_db=mongo_db)
 
@@ -180,7 +181,7 @@ class MongoHook(BaseHook):
     ) -> pymongo.results.UpdateResult:
         """
         Updates a single document in a mongo collection.
-        https://api.mongodb.com/python/current/api/pymongo/collection.html#pymongo.collection.Collection.update_one
+        https://pymongo.readthedocs.io/en/stable/api/pymongo/collection.html#pymongo.collection.Collection.update_one
 
         :param mongo_collection: The name of the collection to update.
         :type mongo_collection: str
@@ -207,7 +208,7 @@ class MongoHook(BaseHook):
     ) -> pymongo.results.UpdateResult:
         """
         Updates one or more documents in a mongo collection.
-        https://api.mongodb.com/python/current/api/pymongo/collection.html#pymongo.collection.Collection.update_many
+        https://pymongo.readthedocs.io/en/stable/api/pymongo/collection.html#pymongo.collection.Collection.update_many
 
         :param mongo_collection: The name of the collection to update.
         :type mongo_collection: str
@@ -234,7 +235,7 @@ class MongoHook(BaseHook):
     ) -> pymongo.results.UpdateResult:
         """
         Replaces a single document in a mongo collection.
-        https://api.mongodb.com/python/current/api/pymongo/collection.html#pymongo.collection.Collection.replace_one
+        https://pymongo.readthedocs.io/en/stable/api/pymongo/collection.html#pymongo.collection.Collection.replace_one
 
         .. note::
             If no ``filter_doc`` is given, it is assumed that the replacement
@@ -272,7 +273,7 @@ class MongoHook(BaseHook):
         Replaces many documents in a mongo collection.
 
         Uses bulk_write with multiple ReplaceOne operations
-        https://api.mongodb.com/python/current/api/pymongo/collection.html#pymongo.collection.Collection.bulk_write
+        https://pymongo.readthedocs.io/en/stable/api/pymongo/collection.html#pymongo.collection.Collection.bulk_write
 
         .. note::
             If no ``filter_docs`` are given, it is assumed that all
@@ -314,7 +315,7 @@ class MongoHook(BaseHook):
     ) -> pymongo.results.DeleteResult:
         """
         Deletes a single document in a mongo collection.
-        https://api.mongodb.com/python/current/api/pymongo/collection.html#pymongo.collection.Collection.delete_one
+        https://pymongo.readthedocs.io/en/stable/api/pymongo/collection.html#pymongo.collection.Collection.delete_one
 
         :param mongo_collection: The name of the collection to delete from.
         :type mongo_collection: str
@@ -334,7 +335,7 @@ class MongoHook(BaseHook):
     ) -> pymongo.results.DeleteResult:
         """
         Deletes one or more documents in a mongo collection.
-        https://api.mongodb.com/python/current/api/pymongo/collection.html#pymongo.collection.Collection.delete_many
+        https://pymongo.readthedocs.io/en/stable/api/pymongo/collection.html#pymongo.collection.Collection.delete_many
 
         :param mongo_collection: The name of the collection to delete from.
         :type mongo_collection: str
diff --git a/tests/providers/amazon/aws/transfers/test_mongo_to_s3.py b/tests/providers/amazon/aws/transfers/test_mongo_to_s3.py
index b0bfb9a..10a8e20 100644
--- a/tests/providers/amazon/aws/transfers/test_mongo_to_s3.py
+++ b/tests/providers/amazon/aws/transfers/test_mongo_to_s3.py
@@ -94,7 +94,7 @@ class TestMongoToS3Operator(unittest.TestCase):
         operator.execute(None)
 
         mock_mongo_hook.return_value.find.assert_called_once_with(
-            mongo_collection=MONGO_COLLECTION, query=MONGO_QUERY, mongo_db=None
+            mongo_collection=MONGO_COLLECTION, query=MONGO_QUERY, mongo_db=None, projection=None
         )
 
         op_stringify = self.mock_operator._stringify
@@ -117,7 +117,7 @@ class TestMongoToS3Operator(unittest.TestCase):
         operator.execute(None)
 
         mock_mongo_hook.return_value.find.assert_called_once_with(
-            mongo_collection=MONGO_COLLECTION, query=MONGO_QUERY, mongo_db=None
+            mongo_collection=MONGO_COLLECTION, query=MONGO_QUERY, mongo_db=None, projection=None
         )
 
         op_stringify = self.mock_operator._stringify
diff --git a/tests/providers/mongo/hooks/test_mongo.py b/tests/providers/mongo/hooks/test_mongo.py
index 8e80178..d52a3b5 100644
--- a/tests/providers/mongo/hooks/test_mongo.py
+++ b/tests/providers/mongo/hooks/test_mongo.py
@@ -251,14 +251,30 @@ class TestMongoHook(unittest.TestCase):
     @unittest.skipIf(mongomock is None, 'mongomock package not present')
     def test_find_many(self):
         collection = mongomock.MongoClient().db.collection
-        objs = [{'test_find_many_1': 'test_value'}, {'test_find_many_2': 'test_value'}]
+        objs = [{'_id': 1, 'test_find_many_1': 'test_value'}, {'_id': 2, 'test_find_many_2': 'test_value'}]
         collection.insert(objs)
 
-        result_objs = self.hook.find(collection, {}, find_one=False)
+        result_objs = self.hook.find(mongo_collection=collection, query={}, projection={}, find_one=False)
 
         assert len(list(result_objs)) > 1
 
     @unittest.skipIf(mongomock is None, 'mongomock package not present')
+    def test_find_many_with_projection(self):
+        collection = mongomock.MongoClient().db.collection
+        objs = [
+            {'_id': '1', 'test_find_many_1': 'test_value', 'field_3': 'a'},
+            {'_id': '2', 'test_find_many_2': 'test_value', 'field_3': 'b'},
+        ]
+        collection.insert(objs)
+
+        projection = {'_id': 0}
+        result_objs = self.hook.find(
+            mongo_collection=collection, query={}, projection=projection, find_one=False
+        )
+
+        self.assertRaises(KeyError, lambda x: x[0]['_id'], result_objs)
+
+    @unittest.skipIf(mongomock is None, 'mongomock package not present')
     def test_aggregate(self):
         collection = mongomock.MongoClient().db.collection
         objs = [