Posted to commits@airflow.apache.org by ta...@apache.org on 2023/03/09 22:02:27 UTC

[airflow] branch main updated: Add support of a different AWS connection for DynamoDB (#29452)

This is an automated email from the ASF dual-hosted git repository.

taragolis pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 3780b01fc4 Add support of a different AWS connection for DynamoDB (#29452)
3780b01fc4 is described below

commit 3780b01fc46385809423bec9ef858be5be64b703
Author: Dmytro Yurchenko <1i...@whoyz.com>
AuthorDate: Thu Mar 9 23:02:18 2023 +0100

    Add support of a different AWS connection for DynamoDB (#29452)
    
    * Add support of a different AWS connection for DynamoDB
    
    In cases where the DynamoDBToS3Operator is used with a DynamoDB
    table and an S3 bucket in different accounts, a separate AWS
    connection is needed (e.g. if you need to assume an IAM role from a
    different account).
    
    Use source_aws_conn_id to specify the AWS connection for accessing
    DynamoDB, and optionally dest_aws_conn_id for S3 bucket access; if
    dest_aws_conn_id is not set, it falls back to source_aws_conn_id
    (see the usage sketch after the diffstat below).
    
    Deprecates aws_conn_id in favour of source_aws_conn_id.
    
    * Update airflow/providers/amazon/aws/transfers/dynamodb_to_s3.py
    
    Co-authored-by: Andrey Anshin <An...@taragol.is>
    
    * Update airflow/providers/amazon/aws/transfers/dynamodb_to_s3.py
    
    Co-authored-by: Andrey Anshin <An...@taragol.is>
    
    * Update airflow/providers/amazon/aws/transfers/dynamodb_to_s3.py
    
    Co-authored-by: Andrey Anshin <An...@taragol.is>
    
    * Update airflow/providers/amazon/aws/transfers/dynamodb_to_s3.py
    
    Co-authored-by: Andrey Anshin <An...@taragol.is>
    
    * Apply suggestions from code review
    
    Co-authored-by: D. Ferruzzi <fe...@amazon.com>
    
    ---------
    
    Co-authored-by: Andrey Anshin <An...@taragol.is>
    Co-authored-by: D. Ferruzzi <fe...@amazon.com>
---
 .../amazon/aws/transfers/dynamodb_to_s3.py         |  48 +++++--
 .../amazon/aws/transfers/test_dynamodb_to_s3.py    | 150 ++++++++++++++++++++-
 2 files changed, 185 insertions(+), 13 deletions(-)
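
For context, here is a minimal usage sketch of the new parameters. The DAG settings, connection ids, and table/bucket names are hypothetical placeholders; only the DynamoDBToS3Operator parameters themselves come from the commit above.

    # Illustrative sketch only: connection ids, table/bucket names, and the DAG
    # are hypothetical; the operator parameters are the ones added in this commit.
    from datetime import datetime

    from airflow import DAG
    from airflow.providers.amazon.aws.transfers.dynamodb_to_s3 import DynamoDBToS3Operator

    with DAG(
        dag_id="example_dynamodb_to_s3_cross_account",  # hypothetical DAG id
        start_date=datetime(2023, 3, 1),
        schedule=None,
        catchup=False,
    ):
        # DynamoDB is read with one connection and S3 is written with another;
        # if dest_aws_conn_id were omitted, it would fall back to source_aws_conn_id.
        DynamoDBToS3Operator(
            task_id="dynamodb_to_s3",
            dynamodb_table_name="my_dynamodb_table",    # hypothetical table
            source_aws_conn_id="aws_dynamodb_account",  # hypothetical connection id
            s3_bucket_name="my-backup-bucket",          # hypothetical bucket
            dest_aws_conn_id="aws_s3_account",          # hypothetical connection id
            file_size=4000,  # flush the temp file to S3 once it reaches this many bytes
        )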

diff --git a/airflow/providers/amazon/aws/transfers/dynamodb_to_s3.py b/airflow/providers/amazon/aws/transfers/dynamodb_to_s3.py
index 8b8d41d5d3..d9ea01cac5 100644
--- a/airflow/providers/amazon/aws/transfers/dynamodb_to_s3.py
+++ b/airflow/providers/amazon/aws/transfers/dynamodb_to_s3.py
@@ -22,6 +22,7 @@ DynamoDB table to S3.
 from __future__ import annotations
 
 import json
+import warnings
 from copy import copy
 from decimal import Decimal
 from os.path import getsize
@@ -30,13 +31,20 @@ from typing import IO, TYPE_CHECKING, Any, Callable, Sequence
 from uuid import uuid4
 
 from airflow.models import BaseOperator
+from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
 from airflow.providers.amazon.aws.hooks.dynamodb import DynamoDBHook
 from airflow.providers.amazon.aws.hooks.s3 import S3Hook
+from airflow.utils.types import NOTSET, ArgNotSet
 
 if TYPE_CHECKING:
     from airflow.utils.context import Context
 
 
+_DEPRECATION_MSG = (
+    "The aws_conn_id parameter has been deprecated. Use the source_aws_conn_id parameter instead."
+)
+
+
 class JSONEncoder(json.JSONEncoder):
     """Custom json encoder implementation"""
 
@@ -52,7 +60,10 @@ def _convert_item_to_json_bytes(item: dict[str, Any]) -> bytes:
 
 
 def _upload_file_to_s3(
-    file_obj: IO, bucket_name: str, s3_key_prefix: str, aws_conn_id: str = "aws_default"
+    file_obj: IO,
+    bucket_name: str,
+    s3_key_prefix: str,
+    aws_conn_id: str | None = AwsBaseHook.default_conn_name,
 ) -> None:
     s3_client = S3Hook(aws_conn_id=aws_conn_id).get_conn()
     file_obj.seek(0)
@@ -78,19 +89,25 @@ class DynamoDBToS3Operator(BaseOperator):
         :ref:`howto/transfer:DynamoDBToS3Operator`
 
     :param dynamodb_table_name: Dynamodb table to replicate data from
+    :param source_aws_conn_id: The Airflow connection used for AWS credentials
+        to access DynamoDB. If this is None or empty then the default boto3
+        behaviour is used. If running Airflow in a distributed manner and
+        source_aws_conn_id is None or empty, then default boto3 configuration
+        would be used (and must be maintained on each worker node).
     :param s3_bucket_name: S3 bucket to replicate data to
     :param file_size: Flush file to s3 if file size >= file_size
     :param dynamodb_scan_kwargs: kwargs pass to <https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/dynamodb.html#DynamoDB.Table.scan>
     :param s3_key_prefix: Prefix of s3 object key
     :param process_func: How we transforms a dynamodb item to bytes. By default we dump the json
-    :param aws_conn_id: The Airflow connection used for AWS credentials.
-        If this is None or empty then the default boto3 behaviour is used. If
-        running Airflow in a distributed manner and aws_conn_id is None or
-        empty, then default boto3 configuration would be used (and must be
-        maintained on each worker node).
+    :param dest_aws_conn_id: The Airflow connection used for AWS credentials
+        to access S3. If this is not set then the source_aws_conn_id connection is used.
+    :param aws_conn_id: The Airflow connection used for AWS credentials (deprecated; use source_aws_conn_id).
+
     """  # noqa: E501
 
     template_fields: Sequence[str] = (
+        "source_aws_conn_id",
+        "dest_aws_conn_id",
         "s3_bucket_name",
         "s3_key_prefix",
         "dynamodb_table_name",
@@ -103,12 +120,14 @@ class DynamoDBToS3Operator(BaseOperator):
         self,
         *,
         dynamodb_table_name: str,
+        source_aws_conn_id: str | None = AwsBaseHook.default_conn_name,
         s3_bucket_name: str,
         file_size: int,
         dynamodb_scan_kwargs: dict[str, Any] | None = None,
         s3_key_prefix: str = "",
         process_func: Callable[[dict[str, Any]], bytes] = _convert_item_to_json_bytes,
-        aws_conn_id: str = "aws_default",
+        dest_aws_conn_id: str | None | ArgNotSet = NOTSET,
+        aws_conn_id: str | None | ArgNotSet = NOTSET,
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)
@@ -118,10 +137,17 @@ class DynamoDBToS3Operator(BaseOperator):
         self.dynamodb_scan_kwargs = dynamodb_scan_kwargs
         self.s3_bucket_name = s3_bucket_name
         self.s3_key_prefix = s3_key_prefix
-        self.aws_conn_id = aws_conn_id
+        if not isinstance(aws_conn_id, ArgNotSet):
+            warnings.warn(_DEPRECATION_MSG, DeprecationWarning, stacklevel=3)
+            self.source_aws_conn_id = aws_conn_id
+        else:
+            self.source_aws_conn_id = source_aws_conn_id
+        self.dest_aws_conn_id = (
+            self.source_aws_conn_id if isinstance(dest_aws_conn_id, ArgNotSet) else dest_aws_conn_id
+        )
 
     def execute(self, context: Context) -> None:
-        hook = DynamoDBHook(aws_conn_id=self.aws_conn_id)
+        hook = DynamoDBHook(aws_conn_id=self.source_aws_conn_id)
         table = hook.get_conn().Table(self.dynamodb_table_name)
 
         scan_kwargs = copy(self.dynamodb_scan_kwargs) if self.dynamodb_scan_kwargs else {}
@@ -135,7 +161,7 @@ class DynamoDBToS3Operator(BaseOperator):
                 raise e
             finally:
                 if err is None:
-                    _upload_file_to_s3(f, self.s3_bucket_name, self.s3_key_prefix, self.aws_conn_id)
+                    _upload_file_to_s3(f, self.s3_bucket_name, self.s3_key_prefix, self.dest_aws_conn_id)
 
     def _scan_dynamodb_and_upload_to_s3(self, temp_file: IO, scan_kwargs: dict, table: Any) -> IO:
         while True:
@@ -153,7 +179,7 @@ class DynamoDBToS3Operator(BaseOperator):
 
             # Upload the file to S3 if reach file size limit
             if getsize(temp_file.name) >= self.file_size:
-                _upload_file_to_s3(temp_file, self.s3_bucket_name, self.s3_key_prefix, self.aws_conn_id)
+                _upload_file_to_s3(temp_file, self.s3_bucket_name, self.s3_key_prefix, self.dest_aws_conn_id)
                 temp_file.close()
 
                 temp_file = NamedTemporaryFile()
diff --git a/tests/providers/amazon/aws/transfers/test_dynamodb_to_s3.py b/tests/providers/amazon/aws/transfers/test_dynamodb_to_s3.py
index 87a143ae36..22e9a19020 100644
--- a/tests/providers/amazon/aws/transfers/test_dynamodb_to_s3.py
+++ b/tests/providers/amazon/aws/transfers/test_dynamodb_to_s3.py
@@ -23,7 +23,11 @@ from unittest.mock import MagicMock, patch
 
 import pytest
 
-from airflow.providers.amazon.aws.transfers.dynamodb_to_s3 import DynamoDBToS3Operator, JSONEncoder
+from airflow.providers.amazon.aws.transfers.dynamodb_to_s3 import (
+    _DEPRECATION_MSG,
+    DynamoDBToS3Operator,
+    JSONEncoder,
+)
 
 
 class TestJSONEncoder:
@@ -107,6 +111,74 @@ class TestDynamodbToS3:
 
         assert [{"a": float(a)}, {"b": float(b)}] == self.output_queue
 
+    @patch("airflow.providers.amazon.aws.transfers.dynamodb_to_s3.S3Hook")
+    @patch("airflow.providers.amazon.aws.transfers.dynamodb_to_s3.DynamoDBHook")
+    def test_dynamodb_to_s3_default_connection(self, mock_aws_dynamodb_hook, mock_s3_hook):
+        responses = [
+            {
+                "Items": [{"a": 1}, {"b": 2}],
+                "LastEvaluatedKey": "123",
+            },
+            {
+                "Items": [{"c": 3}],
+            },
+        ]
+        table = MagicMock()
+        table.return_value.scan.side_effect = responses
+        mock_aws_dynamodb_hook.return_value.get_conn.return_value.Table = table
+
+        s3_client = MagicMock()
+        s3_client.return_value.upload_file = self.mock_upload_file
+        mock_s3_hook.return_value.get_conn = s3_client
+
+        dynamodb_to_s3_operator = DynamoDBToS3Operator(
+            task_id="dynamodb_to_s3",
+            dynamodb_table_name="airflow_rocks",
+            s3_bucket_name="airflow-bucket",
+            file_size=4000,
+        )
+
+        dynamodb_to_s3_operator.execute(context={})
+        aws_conn_id = "aws_default"
+
+        mock_s3_hook.assert_called_with(aws_conn_id=aws_conn_id)
+        mock_aws_dynamodb_hook.assert_called_with(aws_conn_id=aws_conn_id)
+
+    @patch("airflow.providers.amazon.aws.transfers.dynamodb_to_s3.S3Hook")
+    @patch("airflow.providers.amazon.aws.transfers.dynamodb_to_s3.DynamoDBHook")
+    def test_dynamodb_to_s3_with_aws_conn_id(self, mock_aws_dynamodb_hook, mock_s3_hook):
+        responses = [
+            {
+                "Items": [{"a": 1}, {"b": 2}],
+                "LastEvaluatedKey": "123",
+            },
+            {
+                "Items": [{"c": 3}],
+            },
+        ]
+        table = MagicMock()
+        table.return_value.scan.side_effect = responses
+        mock_aws_dynamodb_hook.return_value.get_conn.return_value.Table = table
+
+        s3_client = MagicMock()
+        s3_client.return_value.upload_file = self.mock_upload_file
+        mock_s3_hook.return_value.get_conn = s3_client
+
+        aws_conn_id = "test-conn-id"
+        with pytest.warns(DeprecationWarning, match=_DEPRECATION_MSG):
+            dynamodb_to_s3_operator = DynamoDBToS3Operator(
+                task_id="dynamodb_to_s3",
+                dynamodb_table_name="airflow_rocks",
+                s3_bucket_name="airflow-bucket",
+                file_size=4000,
+                aws_conn_id=aws_conn_id,
+            )
+
+        dynamodb_to_s3_operator.execute(context={})
+
+        mock_s3_hook.assert_called_with(aws_conn_id=aws_conn_id)
+        mock_aws_dynamodb_hook.assert_called_with(aws_conn_id=aws_conn_id)
+
     @patch("airflow.providers.amazon.aws.transfers.dynamodb_to_s3.S3Hook")
     @patch("airflow.providers.amazon.aws.transfers.dynamodb_to_s3.DynamoDBHook")
     def test_dynamodb_to_s3_with_different_aws_conn_id(self, mock_aws_dynamodb_hook, mock_s3_hook):
@@ -133,7 +205,7 @@ class TestDynamodbToS3:
             dynamodb_table_name="airflow_rocks",
             s3_bucket_name="airflow-bucket",
             file_size=4000,
-            aws_conn_id=aws_conn_id,
+            source_aws_conn_id=aws_conn_id,
         )
 
         dynamodb_to_s3_operator.execute(context={})
@@ -142,3 +214,77 @@ class TestDynamodbToS3:
 
         mock_s3_hook.assert_called_with(aws_conn_id=aws_conn_id)
         mock_aws_dynamodb_hook.assert_called_with(aws_conn_id=aws_conn_id)
+
+    @patch("airflow.providers.amazon.aws.transfers.dynamodb_to_s3.S3Hook")
+    @patch("airflow.providers.amazon.aws.transfers.dynamodb_to_s3.DynamoDBHook")
+    def test_dynamodb_to_s3_with_two_different_connections(self, mock_aws_dynamodb_hook, mock_s3_hook):
+        responses = [
+            {
+                "Items": [{"a": 1}, {"b": 2}],
+                "LastEvaluatedKey": "123",
+            },
+            {
+                "Items": [{"c": 3}],
+            },
+        ]
+        table = MagicMock()
+        table.return_value.scan.side_effect = responses
+        mock_aws_dynamodb_hook.return_value.get_conn.return_value.Table = table
+
+        s3_client = MagicMock()
+        s3_client.return_value.upload_file = self.mock_upload_file
+        mock_s3_hook.return_value.get_conn = s3_client
+
+        s3_aws_conn_id = "test-conn-id"
+        dynamodb_conn_id = "test-dynamodb-conn-id"
+        dynamodb_to_s3_operator = DynamoDBToS3Operator(
+            task_id="dynamodb_to_s3",
+            dynamodb_table_name="airflow_rocks",
+            source_aws_conn_id=dynamodb_conn_id,
+            s3_bucket_name="airflow-bucket",
+            file_size=4000,
+            dest_aws_conn_id=s3_aws_conn_id,
+        )
+
+        dynamodb_to_s3_operator.execute(context={})
+
+        assert [{"a": 1}, {"b": 2}, {"c": 3}] == self.output_queue
+
+        mock_s3_hook.assert_called_with(aws_conn_id=s3_aws_conn_id)
+        mock_aws_dynamodb_hook.assert_called_with(aws_conn_id=dynamodb_conn_id)
+
+    @patch("airflow.providers.amazon.aws.transfers.dynamodb_to_s3.S3Hook")
+    @patch("airflow.providers.amazon.aws.transfers.dynamodb_to_s3.DynamoDBHook")
+    def test_dynamodb_to_s3_with_just_dest_aws_conn_id(self, mock_aws_dynamodb_hook, mock_s3_hook):
+        responses = [
+            {
+                "Items": [{"a": 1}, {"b": 2}],
+                "LastEvaluatedKey": "123",
+            },
+            {
+                "Items": [{"c": 3}],
+            },
+        ]
+        table = MagicMock()
+        table.return_value.scan.side_effect = responses
+        mock_aws_dynamodb_hook.return_value.get_conn.return_value.Table = table
+
+        s3_client = MagicMock()
+        s3_client.return_value.upload_file = self.mock_upload_file
+        mock_s3_hook.return_value.get_conn = s3_client
+
+        s3_aws_conn_id = "test-conn-id"
+        dynamodb_to_s3_operator = DynamoDBToS3Operator(
+            task_id="dynamodb_to_s3",
+            dynamodb_table_name="airflow_rocks",
+            s3_bucket_name="airflow-bucket",
+            file_size=4000,
+            dest_aws_conn_id=s3_aws_conn_id,
+        )
+
+        dynamodb_to_s3_operator.execute(context={})
+
+        assert [{"a": 1}, {"b": 2}, {"c": 3}] == self.output_queue
+
+        mock_aws_dynamodb_hook.assert_called_with(aws_conn_id="aws_default")
+        mock_s3_hook.assert_called_with(aws_conn_id=s3_aws_conn_id)
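
To summarize the constructor behaviour exercised by the tests above, the following standalone sketch mirrors the connection-resolution logic. resolve_conn_ids is a hypothetical helper written for illustration, not part of the provider; only the NOTSET/ArgNotSet sentinel and the precedence rules come from the commit.

    # Hypothetical helper mirroring the operator's __init__ logic above:
    # the deprecated aws_conn_id takes precedence over source_aws_conn_id,
    # and dest_aws_conn_id falls back to the resolved source connection.
    from __future__ import annotations

    import warnings

    from airflow.utils.types import NOTSET, ArgNotSet


    def resolve_conn_ids(
        source_aws_conn_id: str | None = "aws_default",
        dest_aws_conn_id: str | None | ArgNotSet = NOTSET,
        aws_conn_id: str | None | ArgNotSet = NOTSET,
    ) -> tuple[str | None, str | None]:
        if not isinstance(aws_conn_id, ArgNotSet):
            warnings.warn(
                "The aws_conn_id parameter has been deprecated."
                " Use the source_aws_conn_id parameter instead.",
                DeprecationWarning,
                stacklevel=2,
            )
            source = aws_conn_id
        else:
            source = source_aws_conn_id
        dest = source if isinstance(dest_aws_conn_id, ArgNotSet) else dest_aws_conn_id
        return source, dest


    # resolve_conn_ids()                                -> ("aws_default", "aws_default")
    # resolve_conn_ids("dynamo", dest_aws_conn_id="s3") -> ("dynamo", "s3")
    # resolve_conn_ids(aws_conn_id="legacy")            -> warns, then ("legacy", "legacy")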