Posted to commits@airflow.apache.org by po...@apache.org on 2020/11/23 08:20:54 UTC

[airflow] branch master updated: [AIRFLOW-5115] Bugfix for S3KeySensor failing to accept template_fields (#12389)

This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/master by this push:
     new ed09915  [AIRFLOW-5115] Bugfix for S3KeySensor failing to accept template_fields (#12389)
ed09915 is described below

commit ed09915a02b9b99e60689e647452addaab1688fc
Author: Dmitriy Synkov <30...@users.noreply.github.com>
AuthorDate: Mon Nov 23 03:18:58 2020 -0500

    [AIRFLOW-5115] Bugfix for S3KeySensor failing to accept template_fields (#12389)
---
 airflow/providers/amazon/aws/sensors/s3_key.py    | 31 +++++++--------
 tests/providers/amazon/aws/sensors/test_s3_key.py | 47 ++++++++++++++++++++---
 2 files changed, 58 insertions(+), 20 deletions(-)
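
The change moves the parsing of bucket_key/bucket_name out of __init__ and
into poke(), so Jinja templates in those fields are rendered by Airflow
before they are parsed. A minimal sketch of the templated usage this fix
enables (the DAG id, task_id, and Variable name below are illustrative,
mirroring the new test_parse_bucket_key_from_jinja test added in this
commit):

    from datetime import datetime

    from airflow.models.dag import DAG
    from airflow.providers.amazon.aws.sensors.s3_key import S3KeySensor

    with DAG("example_s3_key_sensor", start_date=datetime(2020, 1, 1)):
        # bucket_key is one of the sensor's template_fields, so the Jinja
        # expression is only rendered (e.g. to "s3://bucket/key") after
        # __init__ has run; parsing it in __init__ saw the raw "{{ ... }}"
        # string and raised AirflowException before this fix.
        wait_for_key = S3KeySensor(
            task_id="wait_for_key",
            bucket_key="{{ var.value.test_bucket_key }}",
        )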

diff --git a/airflow/providers/amazon/aws/sensors/s3_key.py b/airflow/providers/amazon/aws/sensors/s3_key.py
index 9eab08e..18cb20f 100644
--- a/airflow/providers/amazon/aws/sensors/s3_key.py
+++ b/airflow/providers/amazon/aws/sensors/s3_key.py
@@ -69,30 +69,31 @@ class S3KeySensor(BaseSensorOperator):
         **kwargs,
     ):
         super().__init__(**kwargs)
-        # Parse
-        if bucket_name is None:
-            parsed_url = urlparse(bucket_key)
+
+        self.bucket_name = bucket_name
+        self.bucket_key = bucket_key
+        self.wildcard_match = wildcard_match
+        self.aws_conn_id = aws_conn_id
+        self.verify = verify
+        self.hook: Optional[S3Hook] = None
+
+    def poke(self, context):
+
+        if self.bucket_name is None:
+            parsed_url = urlparse(self.bucket_key)
             if parsed_url.netloc == '':
-                raise AirflowException('Please provide a bucket_name')
-            else:
-                bucket_name = parsed_url.netloc
-                bucket_key = parsed_url.path.lstrip('/')
+                raise AirflowException('If key is a relative path from root, please provide a bucket_name')
+            self.bucket_name = parsed_url.netloc
+            self.bucket_key = parsed_url.path.lstrip('/')
         else:
-            parsed_url = urlparse(bucket_key)
+            parsed_url = urlparse(self.bucket_key)
             if parsed_url.scheme != '' or parsed_url.netloc != '':
                 raise AirflowException(
                     'If bucket_name is provided, bucket_key'
                     + ' should be relative path from root'
                     + ' level, rather than a full s3:// url'
                 )
-        self.bucket_name = bucket_name
-        self.bucket_key = bucket_key
-        self.wildcard_match = wildcard_match
-        self.aws_conn_id = aws_conn_id
-        self.verify = verify
-        self.hook: Optional[S3Hook] = None
 
-    def poke(self, context):
         self.log.info('Poking for key : s3://%s/%s', self.bucket_name, self.bucket_key)
         if self.wildcard_match:
             return self.get_hook().check_for_wildcard_key(self.bucket_key, self.bucket_name)
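
For reference, the parsing now performed in poke() uses urlparse to split an
s3:// URL into bucket (netloc) and key (path); a quick standalone
illustration of that behaviour:

    from urllib.parse import urlparse

    parsed = urlparse("s3://bucket/key")
    print(parsed.netloc)            # 'bucket'
    print(parsed.path.lstrip('/'))  # 'key'

    # A relative key has no netloc, which is what triggers the
    # 'please provide a bucket_name' AirflowException above.
    print(urlparse("file_in_bucket").netloc)  # ''
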
diff --git a/tests/providers/amazon/aws/sensors/test_s3_key.py b/tests/providers/amazon/aws/sensors/test_s3_key.py
index c1df7e7..be88b33 100644
--- a/tests/providers/amazon/aws/sensors/test_s3_key.py
+++ b/tests/providers/amazon/aws/sensors/test_s3_key.py
@@ -17,11 +17,15 @@
 # under the License.
 
 import unittest
+from datetime import datetime
 from unittest import mock
 
 from parameterized import parameterized
 
 from airflow.exceptions import AirflowException
+from airflow.models import TaskInstance
+from airflow.models.dag import DAG
+from airflow.models.variable import Variable
 from airflow.providers.amazon.aws.sensors.s3_key import S3KeySensor
 
 
@@ -32,8 +36,9 @@ class TestS3KeySensor(unittest.TestCase):
         and bucket_key is provided as relative path rather than s3:// url.
         :return:
         """
+        op = S3KeySensor(task_id='s3_key_sensor', bucket_key="file_in_bucket")
         with self.assertRaises(AirflowException):
-            S3KeySensor(task_id='s3_key_sensor', bucket_key="file_in_bucket")
+            op.poke(None)
 
     def test_bucket_name_provided_and_bucket_key_is_s3_url(self):
         """
@@ -41,10 +46,11 @@ class TestS3KeySensor(unittest.TestCase):
         while bucket_key is provided as a full s3:// url.
         :return:
         """
+        op = S3KeySensor(
+            task_id='s3_key_sensor', bucket_key="s3://test_bucket/file", bucket_name='test_bucket'
+        )
         with self.assertRaises(AirflowException):
-            S3KeySensor(
-                task_id='s3_key_sensor', bucket_key="s3://test_bucket/file", bucket_name='test_bucket'
-            )
+            op.poke(None)
 
     @parameterized.expand(
         [
@@ -52,16 +58,47 @@ class TestS3KeySensor(unittest.TestCase):
             ['key', 'bucket', 'key', 'bucket'],
         ]
     )
-    def test_parse_bucket_key(self, key, bucket, parsed_key, parsed_bucket):
+    @mock.patch('airflow.providers.amazon.aws.sensors.s3_key.S3Hook')
+    def test_parse_bucket_key(self, key, bucket, parsed_key, parsed_bucket, mock_hook):
+        mock_hook.return_value.check_for_key.return_value = False
+
         op = S3KeySensor(
             task_id='s3_key_sensor',
             bucket_key=key,
             bucket_name=bucket,
         )
+
+        op.poke(None)
+
         self.assertEqual(op.bucket_key, parsed_key)
         self.assertEqual(op.bucket_name, parsed_bucket)
 
     @mock.patch('airflow.providers.amazon.aws.sensors.s3_key.S3Hook')
+    def test_parse_bucket_key_from_jinja(self, mock_hook):
+        mock_hook.return_value.check_for_key.return_value = False
+
+        Variable.set("test_bucket_key", "s3://bucket/key")
+
+        execution_date = datetime(2020, 1, 1)
+
+        dag = DAG("test_s3_key", start_date=execution_date)
+        op = S3KeySensor(
+            task_id='s3_key_sensor',
+            bucket_key='{{ var.value.test_bucket_key }}',
+            bucket_name=None,
+            dag=dag,
+        )
+
+        ti = TaskInstance(task=op, execution_date=execution_date)
+        context = ti.get_template_context()
+        ti.render_templates(context)
+
+        op.poke(None)
+
+        self.assertEqual(op.bucket_key, "key")
+        self.assertEqual(op.bucket_name, "bucket")
+
+    @mock.patch('airflow.providers.amazon.aws.sensors.s3_key.S3Hook')
     def test_poke(self, mock_hook):
         op = S3KeySensor(task_id='s3_key_sensor', bucket_key='s3://test_bucket/file')