Posted to commits@airflow.apache.org by po...@apache.org on 2021/09/18 17:41:45 UTC

[airflow] branch main updated: Add IAM Role Credentials to S3ToRedshiftTransfer and RedshiftToS3Transfer (#18156)

This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 27088c4  Add IAM Role Credentials to S3ToRedshiftTransfer and RedshiftToS3Transfer (#18156)
27088c4 is described below

commit 27088c4533199a19e6f810abc4e565bc8e107cf0
Author: john-jac <75...@users.noreply.github.com>
AuthorDate: Sat Sep 18 10:41:23 2021 -0700

    Add IAM Role Credentials to S3ToRedshiftTransfer and RedshiftToS3Transfer (#18156)
---
 .../amazon/aws/transfers/redshift_to_s3.py         |  14 ++-
 .../amazon/aws/transfers/s3_to_redshift.py         |  15 ++-
 .../amazon/aws/transfers/test_redshift_to_s3.py    |  83 +++++++++++++++++
 .../amazon/aws/transfers/test_s3_to_redshift.py    | 103 ++++++++++++++++++---
 4 files changed, 195 insertions(+), 20 deletions(-)
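
In practice the change is driven by a single role_arn key in the AWS connection's extras: when it is present, both operators render the credentials clause as aws_iam_role=<arn> instead of static access keys, and otherwise they fall back to the existing S3Hook credential chain. A minimal usage sketch follows; the connection id, role ARN, bucket and table names are illustrative placeholders, not values taken from this commit.

    # Hypothetical setup: an AWS connection whose extras carry the role ARN,
    # created here via the CLI for illustration (all values are placeholders):
    #
    #   airflow connections add aws_conn_id --conn-type aws \
    #       --conn-extra '{"role_arn": "arn:aws:iam::112233445566:role/myRole"}'
    from airflow.providers.amazon.aws.transfers.s3_to_redshift import S3ToRedshiftOperator

    # With role_arn present in the extras, execute() now emits a COPY whose
    # credentials clause is 'aws_iam_role=arn:aws:iam::112233445566:role/myRole'.
    copy_task = S3ToRedshiftOperator(
        task_id="s3_to_redshift_with_iam_role",
        schema="schema",
        table="table",
        s3_bucket="bucket",
        s3_key="key",
        copy_options=[],
        redshift_conn_id="redshift_conn_id",
        aws_conn_id="aws_conn_id",
    )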

diff --git a/airflow/providers/amazon/aws/transfers/redshift_to_s3.py b/airflow/providers/amazon/aws/transfers/redshift_to_s3.py
index df64b29..a3d49c6 100644
--- a/airflow/providers/amazon/aws/transfers/redshift_to_s3.py
+++ b/airflow/providers/amazon/aws/transfers/redshift_to_s3.py
@@ -130,17 +130,23 @@ class RedshiftToS3Operator(BaseOperator):
         return f"""
                     UNLOAD ('{select_query}')
                     TO 's3://{self.s3_bucket}/{s3_key}'
-                    with credentials
+                    credentials
                     '{credentials_block}'
                     {unload_options};
         """
 
     def execute(self, context) -> None:
         postgres_hook = PostgresHook(postgres_conn_id=self.redshift_conn_id)
-        s3_hook = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify)
+        conn = S3Hook.get_connection(conn_id=self.aws_conn_id)
+
+        credentials_block = None
+        if conn.extra_dejson.get('role_arn', False):
+            credentials_block = f"aws_iam_role={conn.extra_dejson['role_arn']}"
+        else:
+            s3_hook = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify)
+            credentials = s3_hook.get_credentials()
+            credentials_block = build_credentials_block(credentials)
 
-        credentials = s3_hook.get_credentials()
-        credentials_block = build_credentials_block(credentials)
         unload_options = '\n\t\t\t'.join(self.unload_options)
 
         unload_query = self._build_unload_query(
diff --git a/airflow/providers/amazon/aws/transfers/s3_to_redshift.py b/airflow/providers/amazon/aws/transfers/s3_to_redshift.py
index 48d80ea..c639eba 100644
--- a/airflow/providers/amazon/aws/transfers/s3_to_redshift.py
+++ b/airflow/providers/amazon/aws/transfers/s3_to_redshift.py
@@ -125,7 +125,7 @@ class S3ToRedshiftOperator(BaseOperator):
         return f"""
                     COPY {copy_destination} {column_names}
                     FROM 's3://{self.s3_bucket}/{self.s3_key}'
-                    with credentials
+                    credentials
                     '{credentials_block}'
                     {copy_options};
         """
@@ -156,9 +156,16 @@ class S3ToRedshiftOperator(BaseOperator):
 
     def execute(self, context) -> None:
         postgres_hook = PostgresHook(postgres_conn_id=self.redshift_conn_id)
-        s3_hook = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify)
-        credentials = s3_hook.get_credentials()
-        credentials_block = build_credentials_block(credentials)
+        conn = S3Hook.get_connection(conn_id=self.aws_conn_id)
+
+        credentials_block = None
+        if conn.extra_dejson.get('role_arn', False):
+            credentials_block = f"aws_iam_role={conn.extra_dejson['role_arn']}"
+        else:
+            s3_hook = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify)
+            credentials = s3_hook.get_credentials()
+            credentials_block = build_credentials_block(credentials)
+
         copy_options = '\n\t\t\t'.join(self.copy_options)
         destination = f'{self.schema}.{self.table}'
         copy_destination = f'#{self.table}' if self.method == 'UPSERT' else destination
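
Both execute() methods above now carry the same credential-resolution branch. For readability, here is that logic as a standalone sketch; the helper name resolve_credentials_block is assumed for illustration only and does not exist in the provider.

    from airflow.providers.amazon.aws.hooks.s3 import S3Hook
    from airflow.providers.amazon.aws.utils.redshift import build_credentials_block

    def resolve_credentials_block(aws_conn_id, verify=None):
        # Prefer an IAM role ARN from the connection extras; otherwise fall
        # back to the session credentials resolved by S3Hook, exactly as the
        # two operators now do in execute().
        conn = S3Hook.get_connection(conn_id=aws_conn_id)
        role_arn = conn.extra_dejson.get('role_arn')
        if role_arn:
            return f"aws_iam_role={role_arn}"
        s3_hook = S3Hook(aws_conn_id=aws_conn_id, verify=verify)
        credentials = s3_hook.get_credentials()
        return build_credentials_block(credentials)
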
diff --git a/tests/providers/amazon/aws/transfers/test_redshift_to_s3.py b/tests/providers/amazon/aws/transfers/test_redshift_to_s3.py
index 89797e9..880cc11 100644
--- a/tests/providers/amazon/aws/transfers/test_redshift_to_s3.py
+++ b/tests/providers/amazon/aws/transfers/test_redshift_to_s3.py
@@ -23,6 +23,7 @@ from unittest import mock
 from boto3.session import Session
 from parameterized import parameterized
 
+from airflow.models.connection import Connection
 from airflow.providers.amazon.aws.transfers.redshift_to_s3 import RedshiftToS3Operator
 from airflow.providers.amazon.aws.utils.redshift import build_credentials_block
 from tests.test_utils.asserts import assert_equal_ignore_multiple_spaces
@@ -35,6 +36,8 @@ class TestRedshiftToS3Transfer(unittest.TestCase):
             [False, "key"],
         ]
     )
+    @mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.get_connection")
+    @mock.patch("airflow.models.connection.Connection")
     @mock.patch("boto3.session.Session")
     @mock.patch("airflow.providers.postgres.hooks.postgres.PostgresHook.run")
     def test_table_unloading(
@@ -43,6 +46,8 @@ class TestRedshiftToS3Transfer(unittest.TestCase):
         expected_s3_key,
         mock_run,
         mock_session,
+        mock_connection,
+        mock_hook,
     ):
         access_key = "aws_access_key_id"
         secret_key = "aws_secret_access_key"
@@ -50,6 +55,8 @@ class TestRedshiftToS3Transfer(unittest.TestCase):
         mock_session.return_value.access_key = access_key
         mock_session.return_value.secret_key = secret_key
         mock_session.return_value.token = None
+        mock_connection.return_value = Connection()
+        mock_hook.return_value = Connection()
         schema = "schema"
         table = "table"
         s3_bucket = "bucket"
@@ -93,6 +100,8 @@ class TestRedshiftToS3Transfer(unittest.TestCase):
             [False, "key"],
         ]
     )
+    @mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.get_connection")
+    @mock.patch("airflow.models.connection.Connection")
     @mock.patch("boto3.session.Session")
     @mock.patch("airflow.providers.postgres.hooks.postgres.PostgresHook.run")
     def test_execute_sts_token(
@@ -101,6 +110,8 @@ class TestRedshiftToS3Transfer(unittest.TestCase):
         expected_s3_key,
         mock_run,
         mock_session,
+        mock_connection,
+        mock_hook,
     ):
         access_key = "ASIA_aws_access_key_id"
         secret_key = "aws_secret_access_key"
@@ -109,6 +120,8 @@ class TestRedshiftToS3Transfer(unittest.TestCase):
         mock_session.return_value.access_key = access_key
         mock_session.return_value.secret_key = secret_key
         mock_session.return_value.token = token
+        mock_connection.return_value = Connection()
+        mock_hook.return_value = Connection()
         schema = "schema"
         table = "table"
         s3_bucket = "bucket"
@@ -155,6 +168,8 @@ class TestRedshiftToS3Transfer(unittest.TestCase):
             [None, True, "key"],
         ]
     )
+    @mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.get_connection")
+    @mock.patch("airflow.models.connection.Connection")
     @mock.patch("boto3.session.Session")
     @mock.patch("airflow.providers.postgres.hooks.postgres.PostgresHook.run")
     def test_custom_select_query_unloading(
@@ -164,6 +179,8 @@ class TestRedshiftToS3Transfer(unittest.TestCase):
         expected_s3_key,
         mock_run,
         mock_session,
+        mock_connection,
+        mock_hook,
     ):
         access_key = "aws_access_key_id"
         secret_key = "aws_secret_access_key"
@@ -171,6 +188,8 @@ class TestRedshiftToS3Transfer(unittest.TestCase):
         mock_session.return_value.access_key = access_key
         mock_session.return_value.secret_key = secret_key
         mock_session.return_value.token = None
+        mock_connection.return_value = Connection()
+        mock_hook.return_value = Connection()
         s3_bucket = "bucket"
         s3_key = "key"
         unload_options = [
@@ -206,6 +225,70 @@ class TestRedshiftToS3Transfer(unittest.TestCase):
         assert secret_key in unload_query
         assert_equal_ignore_multiple_spaces(self, mock_run.call_args[0][0], unload_query)
 
+    @parameterized.expand(
+        [
+            [True, "key/table_"],
+            [False, "key"],
+        ]
+    )
+    @mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.get_connection")
+    @mock.patch("airflow.models.connection.Connection")
+    @mock.patch("boto3.session.Session")
+    @mock.patch("airflow.providers.postgres.hooks.postgres.PostgresHook.run")
+    def test_table_unloading_role_arn(
+        self,
+        table_as_file_name,
+        expected_s3_key,
+        mock_run,
+        mock_session,
+        mock_connection,
+        mock_hook,
+    ):
+        access_key = "aws_access_key_id"
+        secret_key = "aws_secret_access_key"
+        extra = {"role_arn": "arn:aws:iam::112233445566:role/myRole"}
+        mock_session.return_value = Session(access_key, secret_key)
+        mock_session.return_value.access_key = access_key
+        mock_session.return_value.secret_key = secret_key
+        mock_session.return_value.token = None
+        mock_connection.return_value = Connection(extra=extra)
+        mock_hook.return_value = Connection(extra=extra)
+        schema = "schema"
+        table = "table"
+        s3_bucket = "bucket"
+        s3_key = "key"
+        unload_options = [
+            'HEADER',
+        ]
+
+        op = RedshiftToS3Operator(
+            schema=schema,
+            table=table,
+            s3_bucket=s3_bucket,
+            s3_key=s3_key,
+            unload_options=unload_options,
+            include_header=True,
+            redshift_conn_id="redshift_conn_id",
+            aws_conn_id="aws_conn_id",
+            task_id="task_id",
+            table_as_file_name=table_as_file_name,
+            dag=None,
+        )
+
+        op.execute(None)
+
+        unload_options = '\n\t\t\t'.join(unload_options)
+        select_query = f"SELECT * FROM {schema}.{table}"
+        credentials_block = f"aws_iam_role={extra['role_arn']}"
+
+        unload_query = op._build_unload_query(
+            credentials_block, select_query, expected_s3_key, unload_options
+        )
+
+        assert mock_run.call_count == 1
+        assert extra['role_arn'] in unload_query
+        assert_equal_ignore_multiple_spaces(self, mock_run.call_args[0][0], unload_query)
+
     def test_template_fields_overrides(self):
         assert RedshiftToS3Operator.template_fields == (
             's3_bucket',
diff --git a/tests/providers/amazon/aws/transfers/test_s3_to_redshift.py b/tests/providers/amazon/aws/transfers/test_s3_to_redshift.py
index 1ee139e..0cf02b6 100644
--- a/tests/providers/amazon/aws/transfers/test_s3_to_redshift.py
+++ b/tests/providers/amazon/aws/transfers/test_s3_to_redshift.py
@@ -24,14 +24,17 @@ import pytest
 from boto3.session import Session
 
 from airflow.exceptions import AirflowException
+from airflow.models.connection import Connection
 from airflow.providers.amazon.aws.transfers.s3_to_redshift import S3ToRedshiftOperator
 from tests.test_utils.asserts import assert_equal_ignore_multiple_spaces
 
 
 class TestS3ToRedshiftTransfer(unittest.TestCase):
+    @mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.get_connection")
+    @mock.patch("airflow.models.connection.Connection")
     @mock.patch("boto3.session.Session")
     @mock.patch("airflow.providers.postgres.hooks.postgres.PostgresHook.run")
-    def test_execute(self, mock_run, mock_session):
+    def test_execute(self, mock_run, mock_session, mock_connection, mock_hook):
         access_key = "aws_access_key_id"
         secret_key = "aws_secret_access_key"
         mock_session.return_value = Session(access_key, secret_key)
@@ -39,6 +42,9 @@ class TestS3ToRedshiftTransfer(unittest.TestCase):
         mock_session.return_value.secret_key = secret_key
         mock_session.return_value.token = None
 
+        mock_connection.return_value = Connection()
+        mock_hook.return_value = Connection()
+
         schema = "schema"
         table = "table"
         s3_bucket = "bucket"
@@ -60,7 +66,7 @@ class TestS3ToRedshiftTransfer(unittest.TestCase):
         copy_query = '''
                         COPY schema.table
                         FROM 's3://bucket/key'
-                        with credentials
+                        credentials
                         'aws_access_key_id=aws_access_key_id;aws_secret_access_key=aws_secret_access_key'
                         ;
                      '''
@@ -69,9 +75,11 @@ class TestS3ToRedshiftTransfer(unittest.TestCase):
         assert secret_key in copy_query
         assert_equal_ignore_multiple_spaces(self, mock_run.call_args[0][0], copy_query)
 
+    @mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.get_connection")
+    @mock.patch("airflow.models.connection.Connection")
     @mock.patch("boto3.session.Session")
     @mock.patch("airflow.providers.postgres.hooks.postgres.PostgresHook.run")
-    def test_execute_with_column_list(self, mock_run, mock_session):
+    def test_execute_with_column_list(self, mock_run, mock_session, mock_connection, mock_hook):
         access_key = "aws_access_key_id"
         secret_key = "aws_secret_access_key"
         mock_session.return_value = Session(access_key, secret_key)
@@ -79,6 +87,9 @@ class TestS3ToRedshiftTransfer(unittest.TestCase):
         mock_session.return_value.secret_key = secret_key
         mock_session.return_value.token = None
 
+        mock_connection.return_value = Connection()
+        mock_hook.return_value = Connection()
+
         schema = "schema"
         table = "table"
         s3_bucket = "bucket"
@@ -102,7 +113,7 @@ class TestS3ToRedshiftTransfer(unittest.TestCase):
         copy_query = '''
                         COPY schema.table (column_1, column_2)
                         FROM 's3://bucket/key'
-                        with credentials
+                        credentials
                         'aws_access_key_id=aws_access_key_id;aws_secret_access_key=aws_secret_access_key'
                         ;
                      '''
@@ -111,9 +122,11 @@ class TestS3ToRedshiftTransfer(unittest.TestCase):
         assert secret_key in copy_query
         assert_equal_ignore_multiple_spaces(self, mock_run.call_args[0][0], copy_query)
 
+    @mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.get_connection")
+    @mock.patch("airflow.models.connection.Connection")
     @mock.patch("boto3.session.Session")
     @mock.patch("airflow.providers.postgres.hooks.postgres.PostgresHook.run")
-    def test_deprecated_truncate(self, mock_run, mock_session):
+    def test_deprecated_truncate(self, mock_run, mock_session, mock_connection, mock_hook):
         access_key = "aws_access_key_id"
         secret_key = "aws_secret_access_key"
         mock_session.return_value = Session(access_key, secret_key)
@@ -121,6 +134,9 @@ class TestS3ToRedshiftTransfer(unittest.TestCase):
         mock_session.return_value.secret_key = secret_key
         mock_session.return_value.token = None
 
+        mock_connection.return_value = Connection()
+        mock_hook.return_value = Connection()
+
         schema = "schema"
         table = "table"
         s3_bucket = "bucket"
@@ -143,7 +159,7 @@ class TestS3ToRedshiftTransfer(unittest.TestCase):
         copy_statement = '''
                         COPY schema.table
                         FROM 's3://bucket/key'
-                        with credentials
+                        credentials
                         'aws_access_key_id=aws_access_key_id;aws_secret_access_key=aws_secret_access_key'
                         ;
                      '''
@@ -158,9 +174,11 @@ class TestS3ToRedshiftTransfer(unittest.TestCase):
 
         assert mock_run.call_count == 1
 
+    @mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.get_connection")
+    @mock.patch("airflow.models.connection.Connection")
     @mock.patch("boto3.session.Session")
     @mock.patch("airflow.providers.postgres.hooks.postgres.PostgresHook.run")
-    def test_replace(self, mock_run, mock_session):
+    def test_replace(self, mock_run, mock_session, mock_connection, mock_hook):
         access_key = "aws_access_key_id"
         secret_key = "aws_secret_access_key"
         mock_session.return_value = Session(access_key, secret_key)
@@ -168,6 +186,9 @@ class TestS3ToRedshiftTransfer(unittest.TestCase):
         mock_session.return_value.secret_key = secret_key
         mock_session.return_value.token = None
 
+        mock_connection.return_value = Connection()
+        mock_hook.return_value = Connection()
+
         schema = "schema"
         table = "table"
         s3_bucket = "bucket"
@@ -190,7 +211,7 @@ class TestS3ToRedshiftTransfer(unittest.TestCase):
         copy_statement = '''
                         COPY schema.table
                         FROM 's3://bucket/key'
-                        with credentials
+                        credentials
                         'aws_access_key_id=aws_access_key_id;aws_secret_access_key=aws_secret_access_key'
                         ;
                      '''
@@ -205,9 +226,11 @@ class TestS3ToRedshiftTransfer(unittest.TestCase):
 
         assert mock_run.call_count == 1
 
+    @mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.get_connection")
+    @mock.patch("airflow.models.connection.Connection")
     @mock.patch("boto3.session.Session")
     @mock.patch("airflow.providers.postgres.hooks.postgres.PostgresHook.run")
-    def test_upsert(self, mock_run, mock_session):
+    def test_upsert(self, mock_run, mock_session, mock_connection, mock_hook):
         access_key = "aws_access_key_id"
         secret_key = "aws_secret_access_key"
         mock_session.return_value = Session(access_key, secret_key)
@@ -215,6 +238,9 @@ class TestS3ToRedshiftTransfer(unittest.TestCase):
         mock_session.return_value.secret_key = secret_key
         mock_session.return_value.token = None
 
+        mock_connection.return_value = Connection()
+        mock_hook.return_value = Connection()
+
         schema = "schema"
         table = "table"
         s3_bucket = "bucket"
@@ -239,7 +265,7 @@ class TestS3ToRedshiftTransfer(unittest.TestCase):
         copy_statement = f'''
                         COPY #{table}
                         FROM 's3://bucket/key'
-                        with credentials
+                        credentials
                         'aws_access_key_id=aws_access_key_id;aws_secret_access_key=aws_secret_access_key'
                         ;
                      '''
@@ -255,9 +281,11 @@ class TestS3ToRedshiftTransfer(unittest.TestCase):
 
         assert mock_run.call_count == 1
 
+    @mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.get_connection")
+    @mock.patch("airflow.models.connection.Connection")
     @mock.patch("boto3.session.Session")
     @mock.patch("airflow.providers.postgres.hooks.postgres.PostgresHook.run")
-    def test_execute_sts_token(self, mock_run, mock_session):
+    def test_execute_sts_token(self, mock_run, mock_session, mock_connection, mock_hook):
         access_key = "ASIA_aws_access_key_id"
         secret_key = "aws_secret_access_key"
         token = "aws_secret_token"
@@ -266,6 +294,9 @@ class TestS3ToRedshiftTransfer(unittest.TestCase):
         mock_session.return_value.secret_key = secret_key
         mock_session.return_value.token = token
 
+        mock_connection.return_value = Connection()
+        mock_hook.return_value = Connection()
+
         schema = "schema"
         table = "table"
         s3_bucket = "bucket"
@@ -287,7 +318,7 @@ class TestS3ToRedshiftTransfer(unittest.TestCase):
         copy_statement = '''
                             COPY schema.table
                             FROM 's3://bucket/key'
-                            with credentials
+                            credentials
                             'aws_access_key_id=ASIA_aws_access_key_id;aws_secret_access_key=aws_secret_access_key;token=aws_secret_token'
                             ;
                          '''
@@ -297,6 +328,54 @@ class TestS3ToRedshiftTransfer(unittest.TestCase):
         assert mock_run.call_count == 1
         assert_equal_ignore_multiple_spaces(self, mock_run.call_args[0][0], copy_statement)
 
+    @mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.get_connection")
+    @mock.patch("airflow.models.connection.Connection")
+    @mock.patch("boto3.session.Session")
+    @mock.patch("airflow.providers.postgres.hooks.postgres.PostgresHook.run")
+    def test_execute_role_arn(self, mock_run, mock_session, mock_connection, mock_hook):
+        access_key = "ASIA_aws_access_key_id"
+        secret_key = "aws_secret_access_key"
+        token = "aws_secret_token"
+        extra = {"role_arn": "arn:aws:iam::112233445566:role/myRole"}
+
+        mock_session.return_value = Session(access_key, secret_key, token)
+        mock_session.return_value.access_key = access_key
+        mock_session.return_value.secret_key = secret_key
+        mock_session.return_value.token = token
+
+        mock_connection.return_value = Connection(extra=extra)
+        mock_hook.return_value = Connection(extra=extra)
+
+        schema = "schema"
+        table = "table"
+        s3_bucket = "bucket"
+        s3_key = "key"
+        copy_options = ""
+
+        op = S3ToRedshiftOperator(
+            schema=schema,
+            table=table,
+            s3_bucket=s3_bucket,
+            s3_key=s3_key,
+            copy_options=copy_options,
+            redshift_conn_id="redshift_conn_id",
+            aws_conn_id="aws_conn_id",
+            task_id="task_id",
+            dag=None,
+        )
+        op.execute(None)
+        copy_statement = '''
+                            COPY schema.table
+                            FROM 's3://bucket/key'
+                            credentials
+                            'aws_iam_role=arn:aws:iam::112233445566:role/myRole'
+                            ;
+                         '''
+
+        assert extra['role_arn'] in copy_statement
+        assert mock_run.call_count == 1
+        assert_equal_ignore_multiple_spaces(self, mock_run.call_args[0][0], copy_statement)
+
     def test_template_fields_overrides(self):
         assert S3ToRedshiftOperator.template_fields == (
             's3_bucket',