You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by po...@apache.org on 2021/08/12 16:23:58 UTC
[airflow] branch main updated: Add Mongo projections to hook and transfer (#17379)
This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 9875757 Add Mongo projections to hook and transfer (#17379)
9875757 is described below
commit 987575787d82abf5b4e68b669fdb3bcab08965e6
Author: JavierLopezT <ja...@gmail.com>
AuthorDate: Thu Aug 12 18:23:41 2021 +0200
Add Mongo projections to hook and transfer (#17379)
---
.../providers/amazon/aws/transfers/mongo_to_s3.py | 7 ++++++
airflow/providers/mongo/hooks/mongo.py | 29 +++++++++++-----------
.../amazon/aws/transfers/test_mongo_to_s3.py | 4 +--
tests/providers/mongo/hooks/test_mongo.py | 20 +++++++++++++--
4 files changed, 42 insertions(+), 18 deletions(-)
diff --git a/airflow/providers/amazon/aws/transfers/mongo_to_s3.py b/airflow/providers/amazon/aws/transfers/mongo_to_s3.py
index f21c538..d5a39d0 100644
--- a/airflow/providers/amazon/aws/transfers/mongo_to_s3.py
+++ b/airflow/providers/amazon/aws/transfers/mongo_to_s3.py
@@ -41,6 +41,10 @@ class MongoToS3Operator(BaseOperator):
:type mongo_collection: str
:param mongo_query: query to execute. A list including a dict of the query
:type mongo_query: Union[list, dict]
+ :param mongo_projection: optional parameter to filter the fields returned by
+ the query. It can be a list of field names to include or a dictionary
+ for excluding fields (e.g. `mongo_projection={"_id": 0}`).
+ :type mongo_projection: Union[list, dict]
:param s3_bucket: reference to a specific S3 bucket to store the data
:type s3_bucket: str
:param s3_key: in which S3 key the file will be stored
@@ -71,6 +75,7 @@ class MongoToS3Operator(BaseOperator):
s3_bucket: str,
s3_key: str,
mongo_db: Optional[str] = None,
+ mongo_projection: Optional[Union[list, dict]] = None,
replace: bool = False,
allow_disk_use: bool = False,
compression: Optional[str] = None,
@@ -89,6 +94,7 @@ class MongoToS3Operator(BaseOperator):
# Grab query and determine if we need to run an aggregate pipeline
self.mongo_query = mongo_query
self.is_pipeline = isinstance(self.mongo_query, list)
+ self.mongo_projection = mongo_projection
self.s3_bucket = s3_bucket
self.s3_key = s3_key
@@ -113,6 +119,7 @@ class MongoToS3Operator(BaseOperator):
results = MongoHook(self.mongo_conn_id).find(
mongo_collection=self.mongo_collection,
query=cast(dict, self.mongo_query),
+ projection=self.mongo_projection,
mongo_db=self.mongo_db,
)
diff --git a/airflow/providers/mongo/hooks/mongo.py b/airflow/providers/mongo/hooks/mongo.py
index 7d1bb84..90ecf4c 100644
--- a/airflow/providers/mongo/hooks/mongo.py
+++ b/airflow/providers/mongo/hooks/mongo.py
@@ -18,7 +18,7 @@
"""Hook for Mongo DB"""
from ssl import CERT_NONE
from types import TracebackType
-from typing import List, Optional, Type
+from typing import List, Optional, Type, Union
import pymongo
from pymongo import MongoClient, ReplaceOne
@@ -122,8 +122,8 @@ class MongoHook(BaseHook):
) -> pymongo.command_cursor.CommandCursor:
"""
Runs an aggregation pipeline and returns the results
- https://api.mongodb.com/python/current/api/pymongo/collection.html#pymongo.collection.Collection.aggregate
- https://api.mongodb.com/python/current/examples/aggregation.html
+ https://pymongo.readthedocs.io/en/stable/api/pymongo/collection.html#pymongo.collection.Collection.aggregate
+ https://pymongo.readthedocs.io/en/stable/examples/aggregation.html
"""
collection = self.get_collection(mongo_collection, mongo_db=mongo_db)
@@ -135,25 +135,26 @@ class MongoHook(BaseHook):
query: dict,
find_one: bool = False,
mongo_db: Optional[str] = None,
+ projection: Optional[Union[list, dict]] = None,
**kwargs,
) -> pymongo.cursor.Cursor:
"""
Runs a mongo find query and returns the results
- https://api.mongodb.com/python/current/api/pymongo/collection.html#pymongo.collection.Collection.find
+ https://pymongo.readthedocs.io/en/stable/api/pymongo/collection.html#pymongo.collection.Collection.find
"""
collection = self.get_collection(mongo_collection, mongo_db=mongo_db)
if find_one:
- return collection.find_one(query, **kwargs)
+ return collection.find_one(query, projection, **kwargs)
else:
- return collection.find(query, **kwargs)
+ return collection.find(query, projection, **kwargs)
def insert_one(
self, mongo_collection: str, doc: dict, mongo_db: Optional[str] = None, **kwargs
) -> pymongo.results.InsertOneResult:
"""
Inserts a single document into a mongo collection
- https://api.mongodb.com/python/current/api/pymongo/collection.html#pymongo.collection.Collection.insert_one
+ https://pymongo.readthedocs.io/en/stable/api/pymongo/collection.html#pymongo.collection.Collection.insert_one
"""
collection = self.get_collection(mongo_collection, mongo_db=mongo_db)
@@ -164,7 +165,7 @@ class MongoHook(BaseHook):
) -> pymongo.results.InsertManyResult:
"""
Inserts many docs into a mongo collection.
- https://api.mongodb.com/python/current/api/pymongo/collection.html#pymongo.collection.Collection.insert_many
+ https://pymongo.readthedocs.io/en/stable/api/pymongo/collection.html#pymongo.collection.Collection.insert_many
"""
collection = self.get_collection(mongo_collection, mongo_db=mongo_db)
@@ -180,7 +181,7 @@ class MongoHook(BaseHook):
) -> pymongo.results.UpdateResult:
"""
Updates a single document in a mongo collection.
- https://api.mongodb.com/python/current/api/pymongo/collection.html#pymongo.collection.Collection.update_one
+ https://pymongo.readthedocs.io/en/stable/api/pymongo/collection.html#pymongo.collection.Collection.update_one
:param mongo_collection: The name of the collection to update.
:type mongo_collection: str
@@ -207,7 +208,7 @@ class MongoHook(BaseHook):
) -> pymongo.results.UpdateResult:
"""
Updates one or more documents in a mongo collection.
- https://api.mongodb.com/python/current/api/pymongo/collection.html#pymongo.collection.Collection.update_many
+ https://pymongo.readthedocs.io/en/stable/api/pymongo/collection.html#pymongo.collection.Collection.update_many
:param mongo_collection: The name of the collection to update.
:type mongo_collection: str
@@ -234,7 +235,7 @@ class MongoHook(BaseHook):
) -> pymongo.results.UpdateResult:
"""
Replaces a single document in a mongo collection.
- https://api.mongodb.com/python/current/api/pymongo/collection.html#pymongo.collection.Collection.replace_one
+ https://pymongo.readthedocs.io/en/stable/api/pymongo/collection.html#pymongo.collection.Collection.replace_one
.. note::
If no ``filter_doc`` is given, it is assumed that the replacement
@@ -272,7 +273,7 @@ class MongoHook(BaseHook):
Replaces many documents in a mongo collection.
Uses bulk_write with multiple ReplaceOne operations
- https://api.mongodb.com/python/current/api/pymongo/collection.html#pymongo.collection.Collection.bulk_write
+ https://pymongo.readthedocs.io/en/stable/api/pymongo/collection.html#pymongo.collection.Collection.bulk_write
.. note::
If no ``filter_docs``are given, it is assumed that all
@@ -314,7 +315,7 @@ class MongoHook(BaseHook):
) -> pymongo.results.DeleteResult:
"""
Deletes a single document in a mongo collection.
- https://api.mongodb.com/python/current/api/pymongo/collection.html#pymongo.collection.Collection.delete_one
+ https://pymongo.readthedocs.io/en/stable/api/pymongo/collection.html#pymongo.collection.Collection.delete_one
:param mongo_collection: The name of the collection to delete from.
:type mongo_collection: str
@@ -334,7 +335,7 @@ class MongoHook(BaseHook):
) -> pymongo.results.DeleteResult:
"""
Deletes one or more documents in a mongo collection.
- https://api.mongodb.com/python/current/api/pymongo/collection.html#pymongo.collection.Collection.delete_many
+ https://pymongo.readthedocs.io/en/stable/api/pymongo/collection.html#pymongo.collection.Collection.delete_many
:param mongo_collection: The name of the collection to delete from.
:type mongo_collection: str
diff --git a/tests/providers/amazon/aws/transfers/test_mongo_to_s3.py b/tests/providers/amazon/aws/transfers/test_mongo_to_s3.py
index b0bfb9a..10a8e20 100644
--- a/tests/providers/amazon/aws/transfers/test_mongo_to_s3.py
+++ b/tests/providers/amazon/aws/transfers/test_mongo_to_s3.py
@@ -94,7 +94,7 @@ class TestMongoToS3Operator(unittest.TestCase):
operator.execute(None)
mock_mongo_hook.return_value.find.assert_called_once_with(
- mongo_collection=MONGO_COLLECTION, query=MONGO_QUERY, mongo_db=None
+ mongo_collection=MONGO_COLLECTION, query=MONGO_QUERY, mongo_db=None, projection=None
)
op_stringify = self.mock_operator._stringify
@@ -117,7 +117,7 @@ class TestMongoToS3Operator(unittest.TestCase):
operator.execute(None)
mock_mongo_hook.return_value.find.assert_called_once_with(
- mongo_collection=MONGO_COLLECTION, query=MONGO_QUERY, mongo_db=None
+ mongo_collection=MONGO_COLLECTION, query=MONGO_QUERY, mongo_db=None, projection=None
)
op_stringify = self.mock_operator._stringify
diff --git a/tests/providers/mongo/hooks/test_mongo.py b/tests/providers/mongo/hooks/test_mongo.py
index 8e80178..d52a3b5 100644
--- a/tests/providers/mongo/hooks/test_mongo.py
+++ b/tests/providers/mongo/hooks/test_mongo.py
@@ -251,14 +251,30 @@ class TestMongoHook(unittest.TestCase):
@unittest.skipIf(mongomock is None, 'mongomock package not present')
def test_find_many(self):
collection = mongomock.MongoClient().db.collection
- objs = [{'test_find_many_1': 'test_value'}, {'test_find_many_2': 'test_value'}]
+ objs = [{'_id': 1, 'test_find_many_1': 'test_value'}, {'_id': 2, 'test_find_many_2': 'test_value'}]
collection.insert(objs)
- result_objs = self.hook.find(collection, {}, find_one=False)
+ result_objs = self.hook.find(mongo_collection=collection, query={}, projection={}, find_one=False)
assert len(list(result_objs)) > 1
@unittest.skipIf(mongomock is None, 'mongomock package not present')
+ def test_find_many_with_projection(self):
+ collection = mongomock.MongoClient().db.collection
+ objs = [
+ {'_id': '1', 'test_find_many_1': 'test_value', 'field_3': 'a'},
+ {'_id': '2', 'test_find_many_2': 'test_value', 'field_3': 'b'},
+ ]
+ collection.insert(objs)
+
+ projection = {'_id': 0}
+ result_objs = self.hook.find(
+ mongo_collection=collection, query={}, projection=projection, find_one=False
+ )
+
+ self.assertRaises(KeyError, lambda x: x[0]['_id'], result_objs)
+
+ @unittest.skipIf(mongomock is None, 'mongomock package not present')
def test_aggregate(self):
collection = mongomock.MongoClient().db.collection
objs = [