Posted to commits@airflow.apache.org by fe...@apache.org on 2021/01/19 12:26:16 UTC

[airflow] branch master updated: AllowDiskUse parameter and docs in MongotoS3Operator (#12033)

This is an automated email from the ASF dual-hosted git repository.

feluelle pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/master by this push:
     new c065d32  AllowDiskUse parameter and docs in MongotoS3Operator (#12033)
c065d32 is described below

commit c065d32189bfee80ab938d96ad74f6492e9c9b24
Author: JavierLopezT <ja...@gmail.com>
AuthorDate: Tue Jan 19 13:25:53 2021 +0100

    AllowDiskUse parameter and docs in MongotoS3Operator (#12033)
    
    Co-authored-by: RosterIn <48...@users.noreply.github.com>
    Co-authored-by: javier.lopez <ja...@promocionesfarma.com>
---
 .../providers/amazon/aws/transfers/mongo_to_s3.py  | 70 +++++++++++++++-------
 .../amazon/aws/transfers/test_mongo_to_s3.py       | 17 ++++--
 2 files changed, 58 insertions(+), 29 deletions(-)

diff --git a/airflow/providers/amazon/aws/transfers/mongo_to_s3.py b/airflow/providers/amazon/aws/transfers/mongo_to_s3.py
index fe5ccfa..8c72123 100644
--- a/airflow/providers/amazon/aws/transfers/mongo_to_s3.py
+++ b/airflow/providers/amazon/aws/transfers/mongo_to_s3.py
@@ -16,6 +16,7 @@
 # specific language governing permissions and limitations
 # under the License.
 import json
+import warnings
 from typing import Any, Iterable, Optional, Union, cast
 
 from bson import json_util
@@ -25,57 +26,80 @@ from airflow.providers.amazon.aws.hooks.s3 import S3Hook
 from airflow.providers.mongo.hooks.mongo import MongoHook
 from airflow.utils.decorators import apply_defaults
 
+_DEPRECATION_MSG = (
+    "The s3_conn_id parameter has been deprecated. You should pass instead the aws_conn_id parameter."
+)
+
 
 class MongoToS3Operator(BaseOperator):
-    """
-    Mongo -> S3
-        A more specific baseOperator meant to move data
-        from mongo via pymongo to s3 via boto
-
-        things to note
-                .execute() is written to depend on .transform()
-                .transform() is meant to be extended by child classes
-                to perform transformations unique to those operators needs
+    """Operator meant to move data from mongo via pymongo to s3 via boto.
+
+    :param mongo_conn_id: reference to a specific mongo connection
+    :type mongo_conn_id: str
+    :param aws_conn_id: reference to a specific AWS connection
+    :type aws_conn_id: str
+    :param mongo_collection: reference to a specific collection in your mongo db
+    :type mongo_collection: str
+    :param mongo_query: query to execute. A dict for a find query, or a list of dicts for an aggregation pipeline
+    :type mongo_query: Union[list, dict]
+    :param s3_bucket: reference to a specific S3 bucket to store the data
+    :type s3_bucket: str
+    :param s3_key: the S3 key under which the file will be stored
+    :type s3_key: str
+    :param mongo_db: reference to a specific mongo database
+    :type mongo_db: str
+    :param replace: whether to replace the file in S3 if it already exists
+    :type replace: bool
+    :param allow_disk_use: allow the MongoDB server to write temporary data to disk
+        when a query would otherwise exceed its in-memory limit
+    :type allow_disk_use: bool
+    :param compression: type of compression to use for the output file in S3. Currently only gzip is supported.
+    :type compression: str
     """
 
-    template_fields = ['s3_key', 'mongo_query', 'mongo_collection']
+    template_fields = ('s3_bucket', 's3_key', 'mongo_query', 'mongo_collection')
     # pylint: disable=too-many-instance-attributes
 
     @apply_defaults
     def __init__(
         self,
         *,
-        mongo_conn_id: str,
-        s3_conn_id: str,
+        s3_conn_id: Optional[str] = None,
+        mongo_conn_id: str = 'mongo_default',
+        aws_conn_id: str = 'aws_default',
         mongo_collection: str,
         mongo_query: Union[list, dict],
         s3_bucket: str,
         s3_key: str,
         mongo_db: Optional[str] = None,
         replace: bool = False,
+        allow_disk_use: bool = False,
         compression: Optional[str] = None,
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)
-        # Conn Ids
+        if s3_conn_id:
+            warnings.warn(_DEPRECATION_MSG, DeprecationWarning, stacklevel=3)
+            aws_conn_id = s3_conn_id
+
         self.mongo_conn_id = mongo_conn_id
-        self.s3_conn_id = s3_conn_id
-        # Mongo Query Settings
+        self.aws_conn_id = aws_conn_id
         self.mongo_db = mongo_db
         self.mongo_collection = mongo_collection
+
         # Grab query and determine if we need to run an aggregate pipeline
         self.mongo_query = mongo_query
         self.is_pipeline = isinstance(self.mongo_query, list)
 
-        # S3 Settings
         self.s3_bucket = s3_bucket
         self.s3_key = s3_key
         self.replace = replace
+        self.allow_disk_use = allow_disk_use
         self.compression = compression
 
     def execute(self, context) -> bool:
-        """Executed by task_instance at runtime"""
-        s3_conn = S3Hook(self.s3_conn_id)
+        """Is written to depend on transform method"""
+        s3_conn = S3Hook(self.aws_conn_id)
 
         # Grab collection and execute query according to whether or not it is a pipeline
         if self.is_pipeline:
@@ -83,6 +107,7 @@ class MongoToS3Operator(BaseOperator):
                 mongo_collection=self.mongo_collection,
                 aggregate_query=cast(list, self.mongo_query),
                 mongo_db=self.mongo_db,
+                allowDiskUse=self.allow_disk_use,
             )
 
         else:
@@ -90,12 +115,12 @@ class MongoToS3Operator(BaseOperator):
                 mongo_collection=self.mongo_collection,
                 query=cast(dict, self.mongo_query),
                 mongo_db=self.mongo_db,
+                allowDiskUse=self.allow_disk_use,
             )
 
         # Performs transform then stringifies the docs results into json format
         docs_str = self._stringify(self.transform(results))
 
-        # Load Into S3
         s3_conn.load_string(
             string_data=docs_str,
             key=self.s3_key,
@@ -104,8 +129,6 @@ class MongoToS3Operator(BaseOperator):
             compression=self.compression,
         )
 
-        return True
-
     @staticmethod
     def _stringify(iterable: Iterable, joinable: str = '\n') -> str:
         """
@@ -116,9 +139,10 @@ class MongoToS3Operator(BaseOperator):
 
     @staticmethod
     def transform(docs: Any) -> Any:
-        """
+        """This method is meant to be extended by child classes
+        to perform transformations unique to those operators needs.
         Processes pyMongo cursor and returns an iterable with each element being
-                a JSON serializable dictionary
+        a JSON serializable dictionary
 
         Base transform() assumes no processing is needed
         ie. docs is a pyMongo cursor of documents and cursor just
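
For reference, a minimal sketch of how the new allow_disk_use flag might be used
in a DAG (the connection IDs, bucket, collection and key names below are
placeholders for illustration, not part of this commit):

    from datetime import datetime

    from airflow import DAG
    from airflow.providers.amazon.aws.transfers.mongo_to_s3 import MongoToS3Operator

    with DAG(
        dag_id="example_mongo_to_s3",
        start_date=datetime(2021, 1, 1),
        schedule_interval=None,
    ) as dag:
        # Aggregation pipelines on large collections can exceed MongoDB's
        # in-memory limit; allow_disk_use=True lets the server spill
        # temporary data to disk instead of failing the query.
        export_users = MongoToS3Operator(
            task_id="export_users",
            mongo_conn_id="mongo_default",
            aws_conn_id="aws_default",
            mongo_collection="users",
            mongo_query=[{"$match": {"active": True}}],
            s3_bucket="my-example-bucket",
            s3_key="exports/users.json",
            replace=True,
            allow_disk_use=True,
            compression="gzip",
        )

A DAG that still passes s3_conn_id=... keeps working: the operator now emits a
DeprecationWarning and uses that value as aws_conn_id.
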
diff --git a/tests/providers/amazon/aws/transfers/test_mongo_to_s3.py b/tests/providers/amazon/aws/transfers/test_mongo_to_s3.py
index 746b1dd..9e9811a 100644
--- a/tests/providers/amazon/aws/transfers/test_mongo_to_s3.py
+++ b/tests/providers/amazon/aws/transfers/test_mongo_to_s3.py
@@ -25,7 +25,7 @@ from airflow.utils import timezone
 
 TASK_ID = 'test_mongo_to_s3_operator'
 MONGO_CONN_ID = 'default_mongo'
-S3_CONN_ID = 'default_s3'
+AWS_CONN_ID = 'default_s3'
 MONGO_COLLECTION = 'example_collection'
 MONGO_QUERY = {"$lt": "{{ ts + 'Z' }}"}
 S3_BUCKET = 'example_bucket'
@@ -48,7 +48,7 @@ class TestMongoToS3Operator(unittest.TestCase):
         self.mock_operator = MongoToS3Operator(
             task_id=TASK_ID,
             mongo_conn_id=MONGO_CONN_ID,
-            s3_conn_id=S3_CONN_ID,
+            aws_conn_id=AWS_CONN_ID,
             mongo_collection=MONGO_COLLECTION,
             mongo_query=MONGO_QUERY,
             s3_bucket=S3_BUCKET,
@@ -60,7 +60,7 @@ class TestMongoToS3Operator(unittest.TestCase):
     def test_init(self):
         assert self.mock_operator.task_id == TASK_ID
         assert self.mock_operator.mongo_conn_id == MONGO_CONN_ID
-        assert self.mock_operator.s3_conn_id == S3_CONN_ID
+        assert self.mock_operator.aws_conn_id == AWS_CONN_ID
         assert self.mock_operator.mongo_collection == MONGO_COLLECTION
         assert self.mock_operator.mongo_query == MONGO_QUERY
         assert self.mock_operator.s3_bucket == S3_BUCKET
@@ -68,7 +68,12 @@ class TestMongoToS3Operator(unittest.TestCase):
         assert self.mock_operator.compression == COMPRESSION
 
     def test_template_field_overrides(self):
-        assert self.mock_operator.template_fields == ['s3_key', 'mongo_query', 'mongo_collection']
+        assert self.mock_operator.template_fields == (
+            's3_bucket',
+            's3_key',
+            'mongo_query',
+            'mongo_collection',
+        )
 
     def test_render_template(self):
         ti = TaskInstance(self.mock_operator, DEFAULT_DATE)
@@ -89,7 +94,7 @@ class TestMongoToS3Operator(unittest.TestCase):
         operator.execute(None)
 
         mock_mongo_hook.return_value.find.assert_called_once_with(
-            mongo_collection=MONGO_COLLECTION, query=MONGO_QUERY, mongo_db=None
+            mongo_collection=MONGO_COLLECTION, query=MONGO_QUERY, mongo_db=None, allowDiskUse=False
         )
 
         op_stringify = self.mock_operator._stringify
@@ -112,7 +117,7 @@ class TestMongoToS3Operator(unittest.TestCase):
         operator.execute(None)
 
         mock_mongo_hook.return_value.find.assert_called_once_with(
-            mongo_collection=MONGO_COLLECTION, query=MONGO_QUERY, mongo_db=None
+            allowDiskUse=False, mongo_collection=MONGO_COLLECTION, query=MONGO_QUERY, mongo_db=None
         )
 
         op_stringify = self.mock_operator._stringify
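
As the updated transform() docstring notes, child classes are expected to
override it to shape documents before upload; a minimal illustration (the
subclass and field names here are hypothetical, not part of this commit):

    from typing import Any

    from airflow.providers.amazon.aws.transfers.mongo_to_s3 import MongoToS3Operator


    class MongoUserEmailsToS3Operator(MongoToS3Operator):
        """Hypothetical subclass that keeps only the fields it needs."""

        @staticmethod
        def transform(docs: Any) -> Any:
            # Reduce each document to a JSON-serializable dict; execute()
            # then stringifies the result and uploads it to S3.
            return ({"_id": str(doc["_id"]), "email": doc.get("email")} for doc in docs)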