Posted to commits@airflow.apache.org by po...@apache.org on 2021/09/18 17:41:45 UTC
[airflow] branch main updated: Add IAM Role Credentials to S3ToRedshiftTransfer and RedshiftToS3Transfer (#18156)
This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 27088c4 Add IAM Role Credentials to S3ToRedshiftTransfer and RedshiftToS3Transfer (#18156)
27088c4 is described below
commit 27088c4533199a19e6f810abc4e565bc8e107cf0
Author: john-jac <75...@users.noreply.github.com>
AuthorDate: Sat Sep 18 10:41:23 2021 -0700
Add IAM Role Credentials to S3ToRedshiftTransfer and RedshiftToS3Transfer (#18156)
---
.../amazon/aws/transfers/redshift_to_s3.py | 14 ++-
.../amazon/aws/transfers/s3_to_redshift.py | 15 ++-
.../amazon/aws/transfers/test_redshift_to_s3.py | 83 +++++++++++++++++
.../amazon/aws/transfers/test_s3_to_redshift.py | 103 ++++++++++++++++++---
4 files changed, 195 insertions(+), 20 deletions(-)
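
In short: when the AWS connection used by either operator carries a role_arn key in its extras, the operators now build the Redshift CREDENTIALS clause as aws_iam_role=<arn> instead of embedding an access key/secret pair, and the generated SQL drops the leading WITH keyword (Redshift accepts CREDENTIALS with or without it). A minimal sketch of a connection that would take the new code path — the conn_id, conn_type, and ARN are illustrative; only the role_arn extra key matters to the diff below:

    # Hypothetical connection setup; the operators inspect
    # conn.extra_dejson.get('role_arn') as shown in the diff below.
    from airflow.models.connection import Connection

    aws_conn = Connection(
        conn_id='aws_conn_id',   # passed as aws_conn_id= to the operators
        conn_type='aws',
        extra='{"role_arn": "arn:aws:iam::112233445566:role/myRole"}',
    )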
diff --git a/airflow/providers/amazon/aws/transfers/redshift_to_s3.py b/airflow/providers/amazon/aws/transfers/redshift_to_s3.py
index df64b29..a3d49c6 100644
--- a/airflow/providers/amazon/aws/transfers/redshift_to_s3.py
+++ b/airflow/providers/amazon/aws/transfers/redshift_to_s3.py
@@ -130,17 +130,23 @@ class RedshiftToS3Operator(BaseOperator):
return f"""
UNLOAD ('{select_query}')
TO 's3://{self.s3_bucket}/{s3_key}'
- with credentials
+ credentials
'{credentials_block}'
{unload_options};
"""
def execute(self, context) -> None:
postgres_hook = PostgresHook(postgres_conn_id=self.redshift_conn_id)
- s3_hook = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify)
+ conn = S3Hook.get_connection(conn_id=self.aws_conn_id)
+
+ credentials_block = None
+ if conn.extra_dejson.get('role_arn', False):
+ credentials_block = f"aws_iam_role={conn.extra_dejson['role_arn']}"
+ else:
+ s3_hook = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify)
+ credentials = s3_hook.get_credentials()
+ credentials_block = build_credentials_block(credentials)

- credentials = s3_hook.get_credentials()
- credentials_block = build_credentials_block(credentials)
unload_options = '\n\t\t\t'.join(self.unload_options)
unload_query = self._build_unload_query(
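
With a role-bearing connection, _build_unload_query now renders an UNLOAD like the following — a sketch using the values from test_table_unloading_role_arn further down, whitespace condensed:

    # Expected statement under the role_arn path, mirroring the
    # assertions in the new test below.
    unload_query = '''
        UNLOAD ('SELECT * FROM schema.table')
        TO 's3://bucket/key/table_'
        credentials
        'aws_iam_role=arn:aws:iam::112233445566:role/myRole'
        HEADER;
    '''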
diff --git a/airflow/providers/amazon/aws/transfers/s3_to_redshift.py b/airflow/providers/amazon/aws/transfers/s3_to_redshift.py
index 48d80ea..c639eba 100644
--- a/airflow/providers/amazon/aws/transfers/s3_to_redshift.py
+++ b/airflow/providers/amazon/aws/transfers/s3_to_redshift.py
@@ -125,7 +125,7 @@ class S3ToRedshiftOperator(BaseOperator):
return f"""
COPY {copy_destination} {column_names}
FROM 's3://{self.s3_bucket}/{self.s3_key}'
- with credentials
+ credentials
'{credentials_block}'
{copy_options};
"""
@@ -156,9 +156,16 @@ class S3ToRedshiftOperator(BaseOperator):
def execute(self, context) -> None:
postgres_hook = PostgresHook(postgres_conn_id=self.redshift_conn_id)
- s3_hook = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify)
- credentials = s3_hook.get_credentials()
- credentials_block = build_credentials_block(credentials)
+ conn = S3Hook.get_connection(conn_id=self.aws_conn_id)
+
+ credentials_block = None
+ if conn.extra_dejson.get('role_arn', False):
+ credentials_block = f"aws_iam_role={conn.extra_dejson['role_arn']}"
+ else:
+ s3_hook = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify)
+ credentials = s3_hook.get_credentials()
+ credentials_block = build_credentials_block(credentials)
+
copy_options = '\n\t\t\t'.join(self.copy_options)
destination = f'{self.schema}.{self.table}'
copy_destination = f'#{self.table}' if self.method == 'UPSERT' else destination
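
Usage is unchanged on the operator side; the connection alone decides which credentials form is emitted. A hedged sketch of a task that would now authenticate the COPY via the IAM role (task_id, conn ids, bucket, table, and options are placeholders):

    from airflow.providers.amazon.aws.transfers.s3_to_redshift import S3ToRedshiftOperator

    load_task = S3ToRedshiftOperator(
        task_id='load_to_redshift',
        schema='schema',
        table='table',
        s3_bucket='bucket',
        s3_key='key',
        copy_options=['CSV'],            # COPY options pass through unchanged
        redshift_conn_id='redshift_conn_id',
        aws_conn_id='aws_conn_id',       # connection whose extras carry role_arn
    )
    # Resulting SQL (condensed):
    # COPY schema.table FROM 's3://bucket/key'
    # credentials 'aws_iam_role=arn:aws:iam::112233445566:role/myRole' CSV;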
diff --git a/tests/providers/amazon/aws/transfers/test_redshift_to_s3.py b/tests/providers/amazon/aws/transfers/test_redshift_to_s3.py
index 89797e9..880cc11 100644
--- a/tests/providers/amazon/aws/transfers/test_redshift_to_s3.py
+++ b/tests/providers/amazon/aws/transfers/test_redshift_to_s3.py
@@ -23,6 +23,7 @@ from unittest import mock
from boto3.session import Session
from parameterized import parameterized
+from airflow.models.connection import Connection
from airflow.providers.amazon.aws.transfers.redshift_to_s3 import RedshiftToS3Operator
from airflow.providers.amazon.aws.utils.redshift import build_credentials_block
from tests.test_utils.asserts import assert_equal_ignore_multiple_spaces
@@ -35,6 +36,8 @@ class TestRedshiftToS3Transfer(unittest.TestCase):
[False, "key"],
]
)
+ @mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.get_connection")
+ @mock.patch("airflow.models.connection.Connection")
@mock.patch("boto3.session.Session")
@mock.patch("airflow.providers.postgres.hooks.postgres.PostgresHook.run")
def test_table_unloading(
@@ -43,6 +46,8 @@ class TestRedshiftToS3Transfer(unittest.TestCase):
expected_s3_key,
mock_run,
mock_session,
+ mock_connection,
+ mock_hook,
):
access_key = "aws_access_key_id"
secret_key = "aws_secret_access_key"
@@ -50,6 +55,8 @@ class TestRedshiftToS3Transfer(unittest.TestCase):
mock_session.return_value.access_key = access_key
mock_session.return_value.secret_key = secret_key
mock_session.return_value.token = None
+ mock_connection.return_value = Connection()
+ mock_hook.return_value = Connection()
schema = "schema"
table = "table"
s3_bucket = "bucket"
@@ -93,6 +100,8 @@ class TestRedshiftToS3Transfer(unittest.TestCase):
[False, "key"],
]
)
+ @mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.get_connection")
+ @mock.patch("airflow.models.connection.Connection")
@mock.patch("boto3.session.Session")
@mock.patch("airflow.providers.postgres.hooks.postgres.PostgresHook.run")
def test_execute_sts_token(
@@ -101,6 +110,8 @@ class TestRedshiftToS3Transfer(unittest.TestCase):
expected_s3_key,
mock_run,
mock_session,
+ mock_connection,
+ mock_hook,
):
access_key = "ASIA_aws_access_key_id"
secret_key = "aws_secret_access_key"
@@ -109,6 +120,8 @@ class TestRedshiftToS3Transfer(unittest.TestCase):
mock_session.return_value.access_key = access_key
mock_session.return_value.secret_key = secret_key
mock_session.return_value.token = token
+ mock_connection.return_value = Connection()
+ mock_hook.return_value = Connection()
schema = "schema"
table = "table"
s3_bucket = "bucket"
@@ -155,6 +168,8 @@ class TestRedshiftToS3Transfer(unittest.TestCase):
[None, True, "key"],
]
)
+ @mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.get_connection")
+ @mock.patch("airflow.models.connection.Connection")
@mock.patch("boto3.session.Session")
@mock.patch("airflow.providers.postgres.hooks.postgres.PostgresHook.run")
def test_custom_select_query_unloading(
@@ -164,6 +179,8 @@ class TestRedshiftToS3Transfer(unittest.TestCase):
expected_s3_key,
mock_run,
mock_session,
+ mock_connection,
+ mock_hook,
):
access_key = "aws_access_key_id"
secret_key = "aws_secret_access_key"
@@ -171,6 +188,8 @@ class TestRedshiftToS3Transfer(unittest.TestCase):
mock_session.return_value.access_key = access_key
mock_session.return_value.secret_key = secret_key
mock_session.return_value.token = None
+ mock_connection.return_value = Connection()
+ mock_hook.return_value = Connection()
s3_bucket = "bucket"
s3_key = "key"
unload_options = [
@@ -206,6 +225,70 @@ class TestRedshiftToS3Transfer(unittest.TestCase):
assert secret_key in unload_query
assert_equal_ignore_multiple_spaces(self, mock_run.call_args[0][0], unload_query)
+ @parameterized.expand(
+ [
+ [True, "key/table_"],
+ [False, "key"],
+ ]
+ )
+ @mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.get_connection")
+ @mock.patch("airflow.models.connection.Connection")
+ @mock.patch("boto3.session.Session")
+ @mock.patch("airflow.providers.postgres.hooks.postgres.PostgresHook.run")
+ def test_table_unloading_role_arn(
+ self,
+ table_as_file_name,
+ expected_s3_key,
+ mock_run,
+ mock_session,
+ mock_connection,
+ mock_hook,
+ ):
+ access_key = "aws_access_key_id"
+ secret_key = "aws_secret_access_key"
+ extra = {"role_arn": "arn:aws:iam::112233445566:role/myRole"}
+ mock_session.return_value = Session(access_key, secret_key)
+ mock_session.return_value.access_key = access_key
+ mock_session.return_value.secret_key = secret_key
+ mock_session.return_value.token = None
+ mock_connection.return_value = Connection(extra=extra)
+ mock_hook.return_value = Connection(extra=extra)
+ schema = "schema"
+ table = "table"
+ s3_bucket = "bucket"
+ s3_key = "key"
+ unload_options = [
+ 'HEADER',
+ ]
+
+ op = RedshiftToS3Operator(
+ schema=schema,
+ table=table,
+ s3_bucket=s3_bucket,
+ s3_key=s3_key,
+ unload_options=unload_options,
+ include_header=True,
+ redshift_conn_id="redshift_conn_id",
+ aws_conn_id="aws_conn_id",
+ task_id="task_id",
+ table_as_file_name=table_as_file_name,
+ dag=None,
+ )
+
+ op.execute(None)
+
+ unload_options = '\n\t\t\t'.join(unload_options)
+ select_query = f"SELECT * FROM {schema}.{table}"
+ credentials_block = f"aws_iam_role={extra['role_arn']}"
+
+ unload_query = op._build_unload_query(
+ credentials_block, select_query, expected_s3_key, unload_options
+ )
+
+ assert mock_run.call_count == 1
+ assert extra['role_arn'] in unload_query
+ assert_equal_ignore_multiple_spaces(self, mock_run.call_args[0][0], unload_query)
+
def test_template_fields_overrides(self):
assert RedshiftToS3Operator.template_fields == (
's3_bucket',
diff --git a/tests/providers/amazon/aws/transfers/test_s3_to_redshift.py b/tests/providers/amazon/aws/transfers/test_s3_to_redshift.py
index 1ee139e..0cf02b6 100644
--- a/tests/providers/amazon/aws/transfers/test_s3_to_redshift.py
+++ b/tests/providers/amazon/aws/transfers/test_s3_to_redshift.py
@@ -24,14 +24,17 @@ import pytest
from boto3.session import Session
from airflow.exceptions import AirflowException
+from airflow.models.connection import Connection
from airflow.providers.amazon.aws.transfers.s3_to_redshift import S3ToRedshiftOperator
from tests.test_utils.asserts import assert_equal_ignore_multiple_spaces
class TestS3ToRedshiftTransfer(unittest.TestCase):
+ @mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.get_connection")
+ @mock.patch("airflow.models.connection.Connection")
@mock.patch("boto3.session.Session")
@mock.patch("airflow.providers.postgres.hooks.postgres.PostgresHook.run")
- def test_execute(self, mock_run, mock_session):
+ def test_execute(self, mock_run, mock_session, mock_connection, mock_hook):
access_key = "aws_access_key_id"
secret_key = "aws_secret_access_key"
mock_session.return_value = Session(access_key, secret_key)
@@ -39,6 +42,9 @@ class TestS3ToRedshiftTransfer(unittest.TestCase):
mock_session.return_value.secret_key = secret_key
mock_session.return_value.token = None
+ mock_connection.return_value = Connection()
+ mock_hook.return_value = Connection()
+
schema = "schema"
table = "table"
s3_bucket = "bucket"
@@ -60,7 +66,7 @@ class TestS3ToRedshiftTransfer(unittest.TestCase):
copy_query = '''
COPY schema.table
FROM 's3://bucket/key'
- with credentials
+ credentials
'aws_access_key_id=aws_access_key_id;aws_secret_access_key=aws_secret_access_key'
;
'''
@@ -69,9 +75,11 @@ class TestS3ToRedshiftTransfer(unittest.TestCase):
assert secret_key in copy_query
assert_equal_ignore_multiple_spaces(self, mock_run.call_args[0][0], copy_query)
+ @mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.get_connection")
+ @mock.patch("airflow.models.connection.Connection")
@mock.patch("boto3.session.Session")
@mock.patch("airflow.providers.postgres.hooks.postgres.PostgresHook.run")
- def test_execute_with_column_list(self, mock_run, mock_session):
+ def test_execute_with_column_list(self, mock_run, mock_session, mock_connection, mock_hook):
access_key = "aws_access_key_id"
secret_key = "aws_secret_access_key"
mock_session.return_value = Session(access_key, secret_key)
@@ -79,6 +87,9 @@ class TestS3ToRedshiftTransfer(unittest.TestCase):
mock_session.return_value.secret_key = secret_key
mock_session.return_value.token = None
+ mock_connection.return_value = Connection()
+ mock_hook.return_value = Connection()
+
schema = "schema"
table = "table"
s3_bucket = "bucket"
@@ -102,7 +113,7 @@ class TestS3ToRedshiftTransfer(unittest.TestCase):
copy_query = '''
COPY schema.table (column_1, column_2)
FROM 's3://bucket/key'
- with credentials
+ credentials
'aws_access_key_id=aws_access_key_id;aws_secret_access_key=aws_secret_access_key'
;
'''
@@ -111,9 +122,11 @@ class TestS3ToRedshiftTransfer(unittest.TestCase):
assert secret_key in copy_query
assert_equal_ignore_multiple_spaces(self, mock_run.call_args[0][0], copy_query)
+ @mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.get_connection")
+ @mock.patch("airflow.models.connection.Connection")
@mock.patch("boto3.session.Session")
@mock.patch("airflow.providers.postgres.hooks.postgres.PostgresHook.run")
- def test_deprecated_truncate(self, mock_run, mock_session):
+ def test_deprecated_truncate(self, mock_run, mock_session, mock_connection, mock_hook):
access_key = "aws_access_key_id"
secret_key = "aws_secret_access_key"
mock_session.return_value = Session(access_key, secret_key)
@@ -121,6 +134,9 @@ class TestS3ToRedshiftTransfer(unittest.TestCase):
mock_session.return_value.secret_key = secret_key
mock_session.return_value.token = None
+ mock_connection.return_value = Connection()
+ mock_hook.return_value = Connection()
+
schema = "schema"
table = "table"
s3_bucket = "bucket"
@@ -143,7 +159,7 @@ class TestS3ToRedshiftTransfer(unittest.TestCase):
copy_statement = '''
COPY schema.table
FROM 's3://bucket/key'
- with credentials
+ credentials
'aws_access_key_id=aws_access_key_id;aws_secret_access_key=aws_secret_access_key'
;
'''
@@ -158,9 +174,11 @@ class TestS3ToRedshiftTransfer(unittest.TestCase):
assert mock_run.call_count == 1
+ @mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.get_connection")
+ @mock.patch("airflow.models.connection.Connection")
@mock.patch("boto3.session.Session")
@mock.patch("airflow.providers.postgres.hooks.postgres.PostgresHook.run")
- def test_replace(self, mock_run, mock_session):
+ def test_replace(self, mock_run, mock_session, mock_connection, mock_hook):
access_key = "aws_access_key_id"
secret_key = "aws_secret_access_key"
mock_session.return_value = Session(access_key, secret_key)
@@ -168,6 +186,9 @@ class TestS3ToRedshiftTransfer(unittest.TestCase):
mock_session.return_value.secret_key = secret_key
mock_session.return_value.token = None
+ mock_connection.return_value = Connection()
+ mock_hook.return_value = Connection()
+
schema = "schema"
table = "table"
s3_bucket = "bucket"
@@ -190,7 +211,7 @@ class TestS3ToRedshiftTransfer(unittest.TestCase):
copy_statement = '''
COPY schema.table
FROM 's3://bucket/key'
- with credentials
+ credentials
'aws_access_key_id=aws_access_key_id;aws_secret_access_key=aws_secret_access_key'
;
'''
@@ -205,9 +226,11 @@ class TestS3ToRedshiftTransfer(unittest.TestCase):
assert mock_run.call_count == 1
+ @mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.get_connection")
+ @mock.patch("airflow.models.connection.Connection")
@mock.patch("boto3.session.Session")
@mock.patch("airflow.providers.postgres.hooks.postgres.PostgresHook.run")
- def test_upsert(self, mock_run, mock_session):
+ def test_upsert(self, mock_run, mock_session, mock_connection, mock_hook):
access_key = "aws_access_key_id"
secret_key = "aws_secret_access_key"
mock_session.return_value = Session(access_key, secret_key)
@@ -215,6 +238,9 @@ class TestS3ToRedshiftTransfer(unittest.TestCase):
mock_session.return_value.secret_key = secret_key
mock_session.return_value.token = None
+ mock_connection.return_value = Connection()
+ mock_hook.return_value = Connection()
+
schema = "schema"
table = "table"
s3_bucket = "bucket"
@@ -239,7 +265,7 @@ class TestS3ToRedshiftTransfer(unittest.TestCase):
copy_statement = f'''
COPY #{table}
FROM 's3://bucket/key'
- with credentials
+ credentials
'aws_access_key_id=aws_access_key_id;aws_secret_access_key=aws_secret_access_key'
;
'''
@@ -255,9 +281,11 @@ class TestS3ToRedshiftTransfer(unittest.TestCase):
assert mock_run.call_count == 1
+ @mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.get_connection")
+ @mock.patch("airflow.models.connection.Connection")
@mock.patch("boto3.session.Session")
@mock.patch("airflow.providers.postgres.hooks.postgres.PostgresHook.run")
- def test_execute_sts_token(self, mock_run, mock_session):
+ def test_execute_sts_token(self, mock_run, mock_session, mock_connection, mock_hook):
access_key = "ASIA_aws_access_key_id"
secret_key = "aws_secret_access_key"
token = "aws_secret_token"
@@ -266,6 +294,9 @@ class TestS3ToRedshiftTransfer(unittest.TestCase):
mock_session.return_value.secret_key = secret_key
mock_session.return_value.token = token
+ mock_connection.return_value = Connection()
+ mock_hook.return_value = Connection()
+
schema = "schema"
table = "table"
s3_bucket = "bucket"
@@ -287,7 +318,7 @@ class TestS3ToRedshiftTransfer(unittest.TestCase):
copy_statement = '''
COPY schema.table
FROM 's3://bucket/key'
- with credentials
+ credentials
'aws_access_key_id=ASIA_aws_access_key_id;aws_secret_access_key=aws_secret_access_key;token=aws_secret_token'
;
'''
@@ -297,6 +328,54 @@ class TestS3ToRedshiftTransfer(unittest.TestCase):
assert mock_run.call_count == 1
assert_equal_ignore_multiple_spaces(self, mock_run.call_args[0][0], copy_statement)
+ @mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.get_connection")
+ @mock.patch("airflow.models.connection.Connection")
+ @mock.patch("boto3.session.Session")
+ @mock.patch("airflow.providers.postgres.hooks.postgres.PostgresHook.run")
+ def test_execute_role_arn(self, mock_run, mock_session, mock_connection, mock_hook):
+ access_key = "ASIA_aws_access_key_id"
+ secret_key = "aws_secret_access_key"
+ token = "aws_secret_token"
+ extra = {"role_arn": "arn:aws:iam::112233445566:role/myRole"}
+
+ mock_session.return_value = Session(access_key, secret_key, token)
+ mock_session.return_value.access_key = access_key
+ mock_session.return_value.secret_key = secret_key
+ mock_session.return_value.token = token
+
+ mock_connection.return_value = Connection(extra=extra)
+ mock_hook.return_value = Connection(extra=extra)
+
+ schema = "schema"
+ table = "table"
+ s3_bucket = "bucket"
+ s3_key = "key"
+ copy_options = ""
+
+ op = S3ToRedshiftOperator(
+ schema=schema,
+ table=table,
+ s3_bucket=s3_bucket,
+ s3_key=s3_key,
+ copy_options=copy_options,
+ redshift_conn_id="redshift_conn_id",
+ aws_conn_id="aws_conn_id",
+ task_id="task_id",
+ dag=None,
+ )
+ op.execute(None)
+ copy_statement = '''
+ COPY schema.table
+ FROM 's3://bucket/key'
+ credentials
+ 'aws_iam_role=arn:aws:iam::112233445566:role/myRole'
+ ;
+ '''
+
+ assert extra['role_arn'] in copy_statement
+ assert mock_run.call_count == 1
+ assert_equal_ignore_multiple_spaces(self, mock_run.call_args[0][0], copy_statement)
+
def test_template_fields_overrides(self):
assert S3ToRedshiftOperator.template_fields == (
's3_bucket',