Posted to commits@airflow.apache.org by ka...@apache.org on 2020/05/28 10:13:39 UTC

[airflow] branch master updated: Add script_args for S3FileTransformOperator (#9019)

This is an automated email from the ASF dual-hosted git repository.

kaxilnaik pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/master by this push:
     new 1ed171b  Add script_args for S3FileTransformOperator (#9019)
1ed171b is described below

commit 1ed171bfb265ded8674058bdc425640d25f1f4fc
Author: Andrej Švec <an...@gmail.com>
AuthorDate: Thu May 28 12:12:44 2020 +0200

    Add script_args for S3FileTransformOperator (#9019)
    
    Co-authored-by: Andrej Svec <as...@slido.com>
---
 .../amazon/aws/operators/s3_file_transform.py      | 11 ++-
 .../amazon/aws/operators/test_s3_file_transform.py | 99 ++++++++++++----------
 2 files changed, 62 insertions(+), 48 deletions(-)

diff --git a/airflow/providers/amazon/aws/operators/s3_file_transform.py b/airflow/providers/amazon/aws/operators/s3_file_transform.py
index 294440c..c7f2b03 100644
--- a/airflow/providers/amazon/aws/operators/s3_file_transform.py
+++ b/airflow/providers/amazon/aws/operators/s3_file_transform.py
@@ -19,7 +19,7 @@
 import subprocess
 import sys
 from tempfile import NamedTemporaryFile
-from typing import Optional, Union
+from typing import Optional, Sequence, Union
 
 from airflow.exceptions import AirflowException
 from airflow.models import BaseOperator
@@ -52,6 +52,8 @@ class S3FileTransformOperator(BaseOperator):
     :type transform_script: str
     :param select_expression: S3 Select expression
     :type select_expression: str
+    :param script_args: arguments for transformation script (templated)
+    :type script_args: sequence of str
     :param source_aws_conn_id: source s3 connection
     :type source_aws_conn_id: str
     :param source_verify: Whether or not to verify SSL certificates for S3 connection.
@@ -76,7 +78,7 @@ class S3FileTransformOperator(BaseOperator):
     :type replace: bool
     """
 
-    template_fields = ('source_s3_key', 'dest_s3_key')
+    template_fields = ('source_s3_key', 'dest_s3_key', 'script_args')
     template_ext = ()
     ui_color = '#f9c915'
 
@@ -87,12 +89,14 @@ class S3FileTransformOperator(BaseOperator):
             dest_s3_key: str,
             transform_script: Optional[str] = None,
             select_expression=None,
+            script_args: Optional[Sequence[str]] = None,
             source_aws_conn_id: str = 'aws_default',
             source_verify: Optional[Union[bool, str]] = None,
             dest_aws_conn_id: str = 'aws_default',
             dest_verify: Optional[Union[bool, str]] = None,
             replace: bool = False,
             *args, **kwargs) -> None:
+        # pylint: disable=too-many-arguments
         super().__init__(*args, **kwargs)
         self.source_s3_key = source_s3_key
         self.source_aws_conn_id = source_aws_conn_id
@@ -103,6 +107,7 @@ class S3FileTransformOperator(BaseOperator):
         self.replace = replace
         self.transform_script = transform_script
         self.select_expression = select_expression
+        self.script_args = script_args or []
         self.output_encoding = sys.getdefaultencoding()
 
     def execute(self, context):
@@ -137,7 +142,7 @@ class S3FileTransformOperator(BaseOperator):
 
             if self.transform_script is not None:
                 process = subprocess.Popen(
-                    [self.transform_script, f_source.name, f_dest.name],
+                    [self.transform_script, f_source.name, f_dest.name, *self.script_args],
                     stdout=subprocess.PIPE,
                     stderr=subprocess.STDOUT,
                     close_fds=True
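
As the Popen call above shows, the operator now invokes the script as
[transform_script, source, dest, *script_args], so inside the script
sys.argv[1] is the downloaded source file, sys.argv[2] is the destination
file to write, and sys.argv[3:] holds the extra arguments. A minimal
transform-script sketch under those assumptions (the transformation itself
is a placeholder):

    #!/usr/bin/env python3
    # Illustrative transform script, not part of the commit.
    # Invoked as: <script> <source_path> <dest_path> <script_args...>
    import sys

    source, dest = sys.argv[1], sys.argv[2]
    extra_args = sys.argv[3:]  # e.g. ['arg1', 'arg2'] from script_args

    with open(source) as src, open(dest, 'w') as out:
        for line in src:
            out.write(line.upper())  # placeholder transformation

    print(f"transformed with args: {extra_args}")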
diff --git a/tests/providers/amazon/aws/operators/test_s3_file_transform.py b/tests/providers/amazon/aws/operators/test_s3_file_transform.py
index 87339f8..602eef1 100644
--- a/tests/providers/amazon/aws/operators/test_s3_file_transform.py
+++ b/tests/providers/amazon/aws/operators/test_s3_file_transform.py
@@ -53,25 +53,12 @@ class TestS3FileTransformOperator(unittest.TestCase):
     @mock_s3
     def test_execute_with_transform_script(self, mock_log, mock_popen):
         process_output = [b"Foo", b"Bar", b"Baz"]
+        self.mock_process(mock_popen, process_output=process_output)
+        input_path, output_path = self.s3_paths()
 
-        process = mock_popen.return_value
-        process.stdout.readline.side_effect = process_output
-        process.wait.return_value = None
-        process.returncode = 0
-
-        bucket = "bucket"
-        input_key = "foo"
-        output_key = "bar"
-        bio = io.BytesIO(b"input")
-
-        conn = boto3.client('s3')
-        conn.create_bucket(Bucket=bucket)
-        conn.upload_fileobj(Bucket=bucket, Key=input_key, Fileobj=bio)
-
-        s3_url = "s3://{0}/{1}"
         op = S3FileTransformOperator(
-            source_s3_key=s3_url.format(bucket, input_key),
-            dest_s3_key=s3_url.format(bucket, output_key),
+            source_s3_key=input_path,
+            dest_s3_key=output_path,
             transform_script=self.transform_script,
             replace=True,
             task_id="task_id")
@@ -84,24 +71,12 @@ class TestS3FileTransformOperator(unittest.TestCase):
     @mock.patch('subprocess.Popen')
     @mock_s3
     def test_execute_with_failing_transform_script(self, mock_popen):
-        process = mock_popen.return_value
-        process.stdout.readline.side_effect = []
-        process.wait.return_value = None
-        process.returncode = 42
+        self.mock_process(mock_popen, return_code=42)
+        input_path, output_path = self.s3_paths()
 
-        bucket = "bucket"
-        input_key = "foo"
-        output_key = "bar"
-        bio = io.BytesIO(b"input")
-
-        conn = boto3.client('s3')
-        conn.create_bucket(Bucket=bucket)
-        conn.upload_fileobj(Bucket=bucket, Key=input_key, Fileobj=bio)
-
-        s3_url = "s3://{0}/{1}"
         op = S3FileTransformOperator(
-            source_s3_key=s3_url.format(bucket, input_key),
-            dest_s3_key=s3_url.format(bucket, output_key),
+            source_s3_key=input_path,
+            dest_s3_key=output_path,
             transform_script=self.transform_script,
             replace=True,
             task_id="task_id")
@@ -111,9 +86,52 @@ class TestS3FileTransformOperator(unittest.TestCase):
 
         self.assertEqual('Transform script failed: 42', str(e.exception))
 
+    @mock.patch('subprocess.Popen')
+    @mock_s3
+    def test_execute_with_transform_script_args(self, mock_popen):
+        self.mock_process(mock_popen, process_output=[b"Foo", b"Bar", b"Baz"])
+        input_path, output_path = self.s3_paths()
+        script_args = ['arg1', 'arg2']
+
+        op = S3FileTransformOperator(
+            source_s3_key=input_path,
+            dest_s3_key=output_path,
+            transform_script=self.transform_script,
+            script_args=script_args,
+            replace=True,
+            task_id="task_id")
+        op.execute(None)
+
+        self.assertEqual(script_args, mock_popen.call_args[0][0][3:])
+
     @mock.patch('airflow.providers.amazon.aws.hooks.s3.S3Hook.select_key', return_value="input")
     @mock_s3
     def test_execute_with_select_expression(self, mock_select_key):
+        input_path, output_path = self.s3_paths()
+        select_expression = "SELECT * FROM s3object s"
+
+        op = S3FileTransformOperator(
+            source_s3_key=input_path,
+            dest_s3_key=output_path,
+            select_expression=select_expression,
+            replace=True,
+            task_id="task_id")
+        op.execute(None)
+
+        mock_select_key.assert_called_once_with(
+            key=input_path,
+            expression=select_expression
+        )
+
+    @staticmethod
+    def mock_process(mock_popen, return_code=0, process_output=None):
+        process = mock_popen.return_value
+        process.stdout.readline.side_effect = process_output or []
+        process.wait.return_value = None
+        process.returncode = return_code
+
+    @staticmethod
+    def s3_paths():
         bucket = "bucket"
         input_key = "foo"
         output_key = "bar"
@@ -124,16 +142,7 @@ class TestS3FileTransformOperator(unittest.TestCase):
         conn.upload_fileobj(Bucket=bucket, Key=input_key, Fileobj=bio)
 
         s3_url = "s3://{0}/{1}"
-        select_expression = "SELECT * FROM S3Object s"
-        op = S3FileTransformOperator(
-            source_s3_key=s3_url.format(bucket, input_key),
-            dest_s3_key=s3_url.format(bucket, output_key),
-            select_expression=select_expression,
-            replace=True,
-            task_id="task_id")
-        op.execute(None)
+        input_path = s3_url.format(bucket, input_key)
+        output_path = s3_url.format(bucket, output_key)
 
-        mock_select_key.assert_called_once_with(
-            key=s3_url.format(bucket, input_key),
-            expression=select_expression
-        )
+        return input_path, output_path
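
Because 'script_args' is added to template_fields, the arguments are rendered
with Jinja before execution. A hedged sketch of passing a templated value
(placeholder keys and script path; macro availability depends on the task's
execution context):

    # Hypothetical templated task; '{{ ds }}' is rendered at runtime.
    op = S3FileTransformOperator(
        task_id='templated_transform',                    # assumed task id
        source_s3_key='s3://my-bucket/in/{{ ds }}.csv',   # placeholder key
        dest_s3_key='s3://my-bucket/out/{{ ds }}.csv',    # placeholder key
        transform_script='/usr/local/bin/transform.py',   # placeholder path
        script_args=['--run-date', '{{ ds }}'],           # templated argument
        replace=True,
    )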