Posted to commits@airflow.apache.org by bo...@apache.org on 2017/11/18 13:08:01 UTC

incubator-airflow git commit: [AIRFLOW-1795] Correctly call S3Hook after migration to boto3

Repository: incubator-airflow
Updated Branches:
  refs/heads/master 54c03f326 -> 98df0d6e3


[AIRFLOW-1795] Correctly call S3Hook after migration to boto3

In the migration of S3Hook to boto3 the connection ID parameter changed
to `aws_conn_id`. This fixes the uses of `s3_conn_id` in the code base
and adds a note to UPDATING.md about the change.

In correcting the tests for S3ToHiveTransfer I noticed that
S3Hook.get_key was returning a dictionary, rather than the S3.Object
mentioned in its docstring. The important thing that was missing was
the ability to get the key name from the return of a call to
get_wildcard_key.

Closes #2795 from ashb/AIRFLOW-1795-s3hook_boto3_fixes

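To illustrate the rename described above, a minimal sketch (the key below is
hypothetical; the connection IDs are the old and new defaults):

    from airflow.hooks.S3_hook import S3Hook

    # Airflow < 1.9 (boto2-based hook):
    #   hook = S3Hook(s3_conn_id='s3_default')

    # Airflow 1.9 (boto3-based hook): both the keyword and the default
    # connection ID changed.
    hook = S3Hook(aws_conn_id='aws_default')
    exists = hook.check_for_key('s3://my-bucket/some/key.txt')  # hypothetical key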

Project: http://git-wip-us.apache.org/repos/asf/incubator-airflow/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-airflow/commit/98df0d6e
Tree: http://git-wip-us.apache.org/repos/asf/incubator-airflow/tree/98df0d6e
Diff: http://git-wip-us.apache.org/repos/asf/incubator-airflow/diff/98df0d6e

Branch: refs/heads/master
Commit: 98df0d6e3b2e2b439ab46d6c9ba736777202414a
Parents: 54c03f3
Author: Ash Berlin-Taylor <as...@firemirror.com>
Authored: Sat Nov 18 14:07:38 2017 +0100
Committer: Bolke de Bruin <bo...@xs4all.nl>
Committed: Sat Nov 18 14:07:38 2017 +0100

----------------------------------------------------------------------
 UPDATING.md                                     | 13 +++++++-
 airflow/hooks/S3_hook.py                        | 14 ++++----
 airflow/operators/redshift_to_s3_operator.py    | 10 +++---
 airflow/operators/s3_file_transform_operator.py | 20 ++++++------
 airflow/operators/s3_to_hive_operator.py        | 15 ++++-----
 airflow/operators/sensors.py                    | 18 +++++------
 tests/operators/s3_to_hive_operator.py          | 34 +++++++++++---------
 7 files changed, 69 insertions(+), 55 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/98df0d6e/UPDATING.md
----------------------------------------------------------------------
diff --git a/UPDATING.md b/UPDATING.md
index 6abcaf7..3c7d549 100644
--- a/UPDATING.md
+++ b/UPDATING.md
@@ -6,12 +6,23 @@ assists people when migrating to a new version.
 ## Airflow 1.9
 
 ### SSH Hook updates, along with new SSH Operator & SFTP Operator
-  SSH Hook now uses Paramiko library to create ssh client connection, instead of sub-process based ssh command execution previously (<1.9.0), so this is backward incompatible.
+
+SSH Hook now uses Paramiko library to create ssh client connection, instead of sub-process based ssh command execution previously (<1.9.0), so this is backward incompatible.
   - update SSHHook constructor
   - use SSHOperator class in place of SSHExecuteOperator which is removed now. Refer test_ssh_operator.py for usage info.
   - SFTPOperator is added to perform secure file transfer from serverA to serverB. Refer test_sftp_operator.py.py for usage info.
   - No updates are required if you are using ftpHook, it will continue work as is.
 
+### S3Hook switched to use Boto3
+
+The airflow.hooks.S3_hook.S3Hook has been switched to use boto3 instead of the older boto (a.k.a. boto2). This results in a few backwards incompatible changes to S3Hook:
+  - the constructor no longer accepts `s3_conn_id`; the parameter is now called `aws_conn_id`
+  - the default connection is now "aws_default" instead of "s3_default"
+  - the return type of `get_bucket` is now a boto3.s3.Bucket
+  - the return type of `get_key` and `get_wildcard_key` is now a boto3.s3.Object
+
+If you use any of the following classes in your DAGs and specify a connection ID, you will need to rename the connection parameter to `aws_conn_id`: S3ToHiveTransfer, S3PrefixSensor, S3KeySensor, RedshiftToS3Transfer.
+
 ### Logging update
 
 The logging structure of Airflow has been rewritten to make configuration easier and the logging system more transparent.

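As a rough sketch of what the S3Hook note above means for calling code,
assuming get_bucket keeps its single bucket-name argument (the bucket name
and prefix are hypothetical):

    from airflow.hooks.S3_hook import S3Hook

    hook = S3Hook(aws_conn_id='aws_default')

    # get_bucket() now returns a boto3.s3.Bucket resource, so listing keys goes
    # through the boto3 collection API instead of boto2's bucket.list().
    bucket = hook.get_bucket('my-bucket')
    for summary in bucket.objects.filter(Prefix='incoming/'):
        print(summary.key)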
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/98df0d6e/airflow/hooks/S3_hook.py
----------------------------------------------------------------------
diff --git a/airflow/hooks/S3_hook.py b/airflow/hooks/S3_hook.py
index b16566f..226b520 100644
--- a/airflow/hooks/S3_hook.py
+++ b/airflow/hooks/S3_hook.py
@@ -123,7 +123,7 @@ class S3Hook(AwsHook):
 
     def get_key(self, key, bucket_name=None):
         """
-        Returns a boto3.S3.Key object
+        Returns a boto3.s3.Object
 
         :param key: the path to the key
         :type key: str
@@ -132,8 +132,10 @@ class S3Hook(AwsHook):
         """
         if not bucket_name:
             (bucket_name, key) = self.parse_s3_url(key)
-            
-        return self.get_conn().get_object(Bucket=bucket_name, Key=key)
+
+        obj = self.get_resource_type('s3').Object(bucket_name, key)
+        obj.load()
+        return obj
 
     def read_key(self, key, bucket_name=None):
         """
@@ -144,9 +146,9 @@ class S3Hook(AwsHook):
         :param bucket_name: Name of the bucket in which the file is stored
         :type bucket_name: str
         """
-        
+
         obj = self.get_key(key, bucket_name)
-        return obj['Body'].read().decode('utf-8')    
+        return obj.get()['Body'].read().decode('utf-8')
 
     def check_for_wildcard_key(self,
                                wildcard_key, bucket_name=None, delimiter=''):
@@ -159,7 +161,7 @@ class S3Hook(AwsHook):
 
     def get_wildcard_key(self, wildcard_key, bucket_name=None, delimiter=''):
         """
-        Returns a boto3.s3.Key object matching the regular expression
+        Returns a boto3.s3.Object object matching the regular expression
 
         :param regex_key: the path to the key
         :type regex_key: str

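A short, hedged sketch of the behaviour difference this hunk fixes (bucket
and key are hypothetical):

    from airflow.hooks.S3_hook import S3Hook

    hook = S3Hook(aws_conn_id='aws_default')

    # Previously get_key() returned the raw get_object() response, a plain dict,
    # so callers that needed the key name (e.g. after get_wildcard_key) failed
    # with an AttributeError on `.key`.
    obj = hook.get_key('s3://my-bucket/some/key.txt')

    print(obj.key)                                   # the key name, now available
    print(obj.get()['Body'].read().decode('utf-8'))  # same access read_key() uses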
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/98df0d6e/airflow/operators/redshift_to_s3_operator.py
----------------------------------------------------------------------
diff --git a/airflow/operators/redshift_to_s3_operator.py b/airflow/operators/redshift_to_s3_operator.py
index 683ff9c..5553a2a 100644
--- a/airflow/operators/redshift_to_s3_operator.py
+++ b/airflow/operators/redshift_to_s3_operator.py
@@ -30,8 +30,8 @@ class RedshiftToS3Transfer(BaseOperator):
     :type s3_key: string
     :param redshift_conn_id: reference to a specific redshift database
     :type redshift_conn_id: string
-    :param s3_conn_id: reference to a specific S3 connection
-    :type s3_conn_id: string
+    :param aws_conn_id: reference to a specific S3 connection
+    :type aws_conn_id: string
     :param options: reference to a list of UNLOAD options
     :type options: list
     """
@@ -48,7 +48,7 @@ class RedshiftToS3Transfer(BaseOperator):
             s3_bucket,
             s3_key,
             redshift_conn_id='redshift_default',
-            s3_conn_id='s3_default',
+            aws_conn_id='aws_default',
             unload_options=tuple(),
             autocommit=False,
             parameters=None,
@@ -59,14 +59,14 @@ class RedshiftToS3Transfer(BaseOperator):
         self.s3_bucket = s3_bucket
         self.s3_key = s3_key
         self.redshift_conn_id = redshift_conn_id
-        self.s3_conn_id = s3_conn_id
+        self.aws_conn_id = aws_conn_id
         self.unload_options = unload_options
         self.autocommit = autocommit
         self.parameters = parameters
 
     def execute(self, context):
         self.hook = PostgresHook(postgres_conn_id=self.redshift_conn_id)
-        self.s3 = S3Hook(s3_conn_id=self.s3_conn_id)
+        self.s3 = S3Hook(aws_conn_id=self.aws_conn_id)
         a_key, s_key = self.s3.get_credentials()
         unload_options = '\n\t\t\t'.join(self.unload_options)
 

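For RedshiftToS3Transfer only the connection keyword changes; a sketch under
the assumption that the schema/table arguments are unchanged (all names other
than aws_conn_id are hypothetical):

    from airflow.operators.redshift_to_s3_operator import RedshiftToS3Transfer

    unload = RedshiftToS3Transfer(
        task_id='unload_orders',
        schema='public',
        table='orders',
        s3_bucket='my-bucket',
        s3_key='exports/orders',
        redshift_conn_id='redshift_default',
        aws_conn_id='aws_default',          # was: s3_conn_id='s3_default'
        unload_options=['ALLOWOVERWRITE'],
    )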
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/98df0d6e/airflow/operators/s3_file_transform_operator.py
----------------------------------------------------------------------
diff --git a/airflow/operators/s3_file_transform_operator.py b/airflow/operators/s3_file_transform_operator.py
index 68c733c..e105e3d 100644
--- a/airflow/operators/s3_file_transform_operator.py
+++ b/airflow/operators/s3_file_transform_operator.py
@@ -37,12 +37,12 @@ class S3FileTransformOperator(BaseOperator):
 
     :param source_s3_key: The key to be retrieved from S3
     :type source_s3_key: str
-    :param source_s3_conn_id: source s3 connection
-    :type source_s3_conn_id: str
+    :param source_aws_conn_id: source s3 connection
+    :type source_aws_conn_id: str
     :param dest_s3_key: The key to be written from S3
     :type dest_s3_key: str
-    :param dest_s3_conn_id: destination s3 connection
-    :type dest_s3_conn_id: str
+    :param dest_aws_conn_id: destination s3 connection
+    :type dest_aws_conn_id: str
     :param replace: Replace dest S3 key if it already exists
     :type replace: bool
     :param transform_script: location of the executable transformation script
@@ -59,21 +59,21 @@ class S3FileTransformOperator(BaseOperator):
             source_s3_key,
             dest_s3_key,
             transform_script,
-            source_s3_conn_id='s3_default',
-            dest_s3_conn_id='s3_default',
+            source_aws_conn_id='aws_default',
+            dest_aws_conn_id='aws_default',
             replace=False,
             *args, **kwargs):
         super(S3FileTransformOperator, self).__init__(*args, **kwargs)
         self.source_s3_key = source_s3_key
-        self.source_s3_conn_id = source_s3_conn_id
+        self.source_aws_conn_id = source_aws_conn_id
         self.dest_s3_key = dest_s3_key
-        self.dest_s3_conn_id = dest_s3_conn_id
+        self.dest_aws_conn_id = dest_aws_conn_id
         self.replace = replace
         self.transform_script = transform_script
 
     def execute(self, context):
-        source_s3 = S3Hook(s3_conn_id=self.source_s3_conn_id)
-        dest_s3 = S3Hook(s3_conn_id=self.dest_s3_conn_id)
+        source_s3 = S3Hook(aws_conn_id=self.source_aws_conn_id)
+        dest_s3 = S3Hook(aws_conn_id=self.dest_aws_conn_id)
         self.log.info("Downloading source S3 file %s", self.source_s3_key)
         if not source_s3.check_for_key(self.source_s3_key):
             raise AirflowException("The source key {0} does not exist".format(self.source_s3_key))

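The corresponding change for S3FileTransformOperator, sketched with
hypothetical keys and script path:

    from airflow.operators.s3_file_transform_operator import S3FileTransformOperator

    transform = S3FileTransformOperator(
        task_id='transform_file',
        source_s3_key='s3://my-bucket/in/data.csv',
        dest_s3_key='s3://my-bucket/out/data.csv',
        transform_script='/usr/local/bin/my_transform.sh',
        source_aws_conn_id='aws_default',   # was: source_s3_conn_id='s3_default'
        dest_aws_conn_id='aws_default',     # was: dest_s3_conn_id='s3_default'
        replace=True,
    )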
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/98df0d6e/airflow/operators/s3_to_hive_operator.py
----------------------------------------------------------------------
diff --git a/airflow/operators/s3_to_hive_operator.py b/airflow/operators/s3_to_hive_operator.py
index 2b4aceb..148c643 100644
--- a/airflow/operators/s3_to_hive_operator.py
+++ b/airflow/operators/s3_to_hive_operator.py
@@ -71,8 +71,8 @@ class S3ToHiveTransfer(BaseOperator):
     :type wildcard_match: bool
     :param delimiter: field delimiter in the file
     :type delimiter: str
-    :param s3_conn_id: source s3 connection
-    :type s3_conn_id: str
+    :param aws_conn_id: source s3 connection
+    :type aws_conn_id: str
     :param hive_cli_conn_id: destination hive connection
     :type hive_cli_conn_id: str
     :param input_compressed: Boolean to determine if file decompression is
@@ -99,7 +99,7 @@ class S3ToHiveTransfer(BaseOperator):
             headers=False,
             check_headers=False,
             wildcard_match=False,
-            s3_conn_id='s3_default',
+            aws_conn_id='aws_default',
             hive_cli_conn_id='hive_cli_default',
             input_compressed=False,
             tblproperties=None,
@@ -116,7 +116,7 @@ class S3ToHiveTransfer(BaseOperator):
         self.check_headers = check_headers
         self.wildcard_match = wildcard_match
         self.hive_cli_conn_id = hive_cli_conn_id
-        self.s3_conn_id = s3_conn_id
+        self.aws_conn_id = aws_conn_id
         self.input_compressed = input_compressed
         self.tblproperties = tblproperties
 
@@ -127,7 +127,7 @@ class S3ToHiveTransfer(BaseOperator):
 
     def execute(self, context):
         # Downloading file from S3
-        self.s3 = S3Hook(s3_conn_id=self.s3_conn_id)
+        self.s3 = S3Hook(aws_conn_id=self.aws_conn_id)
         self.hive = HiveCliHook(hive_cli_conn_id=self.hive_cli_conn_id)
         self.log.info("Downloading S3 file")
 
@@ -143,14 +143,13 @@ class S3ToHiveTransfer(BaseOperator):
             s3_key_object = self.s3.get_key(self.s3_key)
         root, file_ext = os.path.splitext(s3_key_object.key)
         with TemporaryDirectory(prefix='tmps32hive_') as tmp_dir,\
-                NamedTemporaryFile(mode="w",
+                NamedTemporaryFile(mode="wb",
                                    dir=tmp_dir,
                                    suffix=file_ext) as f:
             self.log.info("Dumping S3 key {0} contents to local file {1}"
                           .format(s3_key_object.key, f.name))
-            s3_key_object.get_contents_to_file(f)
+            s3_key_object.download_fileobj(f)
             f.flush()
-            self.s3.connection.close()
             if not self.headers:
                 self.log.info("Loading file %s into Hive", f.name)
                 self.hive.load_file(

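The hunk above also swaps boto2's get_contents_to_file for boto3's
download_fileobj and opens the temporary file in binary mode; the underlying
boto3 pattern, sketched with a hypothetical bucket and key, looks like this:

    import boto3
    from tempfile import NamedTemporaryFile

    s3 = boto3.resource('s3')
    obj = s3.Object('my-bucket', 'path/to/file.gz')

    # download_fileobj() writes bytes, which is why the operator now uses
    # mode="wb" rather than mode="w" for the temporary file.
    with NamedTemporaryFile(mode='wb', suffix='.gz') as f:
        obj.download_fileobj(f)
        f.flush()
        # f.name is then handed to HiveCliHook.load_file by the operator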
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/98df0d6e/airflow/operators/sensors.py
----------------------------------------------------------------------
diff --git a/airflow/operators/sensors.py b/airflow/operators/sensors.py
index da7a62f..bd073b8 100644
--- a/airflow/operators/sensors.py
+++ b/airflow/operators/sensors.py
@@ -501,8 +501,8 @@ class S3KeySensor(BaseSensorOperator):
     :param wildcard_match: whether the bucket_key should be interpreted as a
         Unix wildcard pattern
     :type wildcard_match: bool
-    :param s3_conn_id: a reference to the s3 connection
-    :type s3_conn_id: str
+    :param aws_conn_id: a reference to the s3 connection
+    :type aws_conn_id: str
     """
     template_fields = ('bucket_key', 'bucket_name')
 
@@ -511,7 +511,7 @@ class S3KeySensor(BaseSensorOperator):
             self, bucket_key,
             bucket_name=None,
             wildcard_match=False,
-            s3_conn_id='s3_default',
+            aws_conn_id='aws_default',
             *args, **kwargs):
         super(S3KeySensor, self).__init__(*args, **kwargs)
         # Parse
@@ -528,11 +528,11 @@ class S3KeySensor(BaseSensorOperator):
         self.bucket_name = bucket_name
         self.bucket_key = bucket_key
         self.wildcard_match = wildcard_match
-        self.s3_conn_id = s3_conn_id
+        self.aws_conn_id = aws_conn_id
 
     def poke(self, context):
         from airflow.hooks.S3_hook import S3Hook
-        hook = S3Hook(s3_conn_id=self.s3_conn_id)
+        hook = S3Hook(aws_conn_id=self.aws_conn_id)
         full_url = "s3://" + self.bucket_name + "/" + self.bucket_key
         self.log.info('Poking for key : {full_url}'.format(**locals()))
         if self.wildcard_match:
@@ -565,7 +565,7 @@ class S3PrefixSensor(BaseSensorOperator):
     def __init__(
             self, bucket_name,
             prefix, delimiter='/',
-            s3_conn_id='s3_default',
+            aws_conn_id='aws_default',
             *args, **kwargs):
         super(S3PrefixSensor, self).__init__(*args, **kwargs)
         # Parse
@@ -573,13 +573,13 @@ class S3PrefixSensor(BaseSensorOperator):
         self.prefix = prefix
         self.delimiter = delimiter
         self.full_url = "s3://" + bucket_name + '/' + prefix
-        self.s3_conn_id = s3_conn_id
+        self.aws_conn_id = aws_conn_id
 
     def poke(self, context):
         self.log.info('Poking for prefix : {self.prefix}\n'
-                     'in bucket s3://{self.bucket_name}'.format(**locals()))
+                      'in bucket s3://{self.bucket_name}'.format(**locals()))
         from airflow.hooks.S3_hook import S3Hook
-        hook = S3Hook(s3_conn_id=self.s3_conn_id)
+        hook = S3Hook(aws_conn_id=self.aws_conn_id)
         return hook.check_for_prefix(
             prefix=self.prefix,
             delimiter=self.delimiter,

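Both sensors keep their other arguments; only the connection keyword is
renamed. A sketch with hypothetical bucket, key and prefix:

    from airflow.operators.sensors import S3KeySensor, S3PrefixSensor

    wait_for_key = S3KeySensor(
        task_id='wait_for_key',
        bucket_key='s3://my-bucket/incoming/data.csv',
        wildcard_match=False,
        aws_conn_id='aws_default',        # was: s3_conn_id='s3_default'
    )

    wait_for_prefix = S3PrefixSensor(
        task_id='wait_for_prefix',
        bucket_name='my-bucket',
        prefix='incoming/',
        aws_conn_id='aws_default',
    )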
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/98df0d6e/tests/operators/s3_to_hive_operator.py
----------------------------------------------------------------------
diff --git a/tests/operators/s3_to_hive_operator.py b/tests/operators/s3_to_hive_operator.py
index faab11e..021c9c4 100644
--- a/tests/operators/s3_to_hive_operator.py
+++ b/tests/operators/s3_to_hive_operator.py
@@ -32,6 +32,12 @@ import shutil
 import filecmp
 import errno
 
+try:
+    import boto3
+    from moto import mock_s3
+except ImportError:
+    mock_s3 = None
+
 
 class S3ToHiveTransferTest(unittest.TestCase):
 
@@ -128,10 +134,6 @@ class S3ToHiveTransferTest(unittest.TestCase):
         key = ext + "_" + ('h' if header else 'nh')
         return key
 
-    def _cp_file_contents(self, fn_src, fn_dest):
-        with open(fn_src, 'rb') as f_src, open(fn_dest, 'wb') as f_dest:
-            shutil.copyfileobj(f_src, f_dest)
-
     def _check_file_equality(self, fn_1, fn_2, ext):
         # gz files contain mtime and filename in the header that
         # causes filecmp to return False even if contents are identical
@@ -205,13 +207,15 @@ class S3ToHiveTransferTest(unittest.TestCase):
                         msg="bz2 Compressed file not as expected")
 
     @unittest.skipIf(mock is None, 'mock package not present')
+    @unittest.skipIf(mock_s3 is None, 'moto package not present')
     @mock.patch('airflow.operators.s3_to_hive_operator.HiveCliHook')
-    @mock.patch('airflow.operators.s3_to_hive_operator.S3Hook')
-    def test_execute(self, mock_s3hook, mock_hiveclihook):
+    @mock_s3
+    def test_execute(self, mock_hiveclihook):
+        conn = boto3.client('s3')
+        conn.create_bucket(Bucket='bucket')
+
         # Testing txt, zip, bz2 files with and without header row
-        for test in product(['.txt', '.gz', '.bz2'], [True, False]):
-            ext = test[0]
-            has_header = test[1]
+        for (ext, has_header) in product(['.txt', '.gz', '.bz2'], [True, False]):
             self.kwargs['headers'] = has_header
             self.kwargs['check_headers'] = has_header
             logging.info("Testing {0} format {1} header".
@@ -219,15 +223,13 @@ class S3ToHiveTransferTest(unittest.TestCase):
                                 ('with' if has_header else 'without'))
                          )
             self.kwargs['input_compressed'] = (False if ext == '.txt' else True)
-            self.kwargs['s3_key'] = self.s3_key + ext
+            self.kwargs['s3_key'] = 's3://bucket/' + self.s3_key + ext
             ip_fn = self._get_fn(ext, self.kwargs['headers'])
             op_fn = self._get_fn(ext, False)
-            # Mock s3 object returned by S3Hook
-            mock_s3_object = mock.Mock(key=self.kwargs['s3_key'])
-            mock_s3_object.get_contents_to_file.side_effect = \
-                lambda dest_file: \
-                self._cp_file_contents(ip_fn, dest_file.name)
-            mock_s3hook().get_key.return_value = mock_s3_object
+
+            # Upload the file into the Mocked S3 bucket
+            conn.upload_file(ip_fn, 'bucket', self.s3_key + ext)
+
             # file paramter to HiveCliHook.load_file is compared
             # against expected file oputput
             mock_hiveclihook().load_file.side_effect = \