Posted to commits@airflow.apache.org by GitBox <gi...@apache.org> on 2018/12/17 23:51:00 UTC

[GitHub] stale[bot] closed pull request #2559: [AIRFLOW-1558] Py3 fix for S3FileTransformOperator

URL: https://github.com/apache/incubator-airflow/pull/2559

This is a pull request from a forked repository. As GitHub does not
show the original diff once such a PR has been closed, it is
reproduced below for the sake of provenance:

diff --git a/airflow/operators/s3_file_transform_operator.py b/airflow/operators/s3_file_transform_operator.py
index 1cdd0e5e48..3f2a2a990a 100644
--- a/airflow/operators/s3_file_transform_operator.py
+++ b/airflow/operators/s3_file_transform_operator.py
@@ -14,6 +14,7 @@
 
 import logging
 from tempfile import NamedTemporaryFile
+import six
 import subprocess
 
 from airflow.exceptions import AirflowException
@@ -81,7 +82,7 @@ def execute(self, context):
             raise AirflowException("The source key {0} does not exist"
                             "".format(self.source_s3_key))
         source_s3_key_object = source_s3.get_key(self.source_s3_key)
-        with NamedTemporaryFile("w") as f_source, NamedTemporaryFile("w") as f_dest:
+        with NamedTemporaryFile("wb") as f_source, NamedTemporaryFile("wb") as f_dest:
             logging.info("Dumping S3 file {0} contents to local file {1}"
                          "".format(self.source_s3_key, f_source.name))
             source_s3_key_object.get_contents_to_file(f_source)
@@ -91,9 +92,13 @@ def execute(self, context):
                 [self.transform_script, f_source.name, f_dest.name],
                 stdout=subprocess.PIPE, stderr=subprocess.PIPE)
             (transform_script_stdoutdata, transform_script_stderrdata) = transform_script_process.communicate()
+            if six.PY3:
+                transform_script_stdoutdata = transform_script_stdoutdata.decode()
             logging.info("Transform script stdout "
                          "" + transform_script_stdoutdata)
             if transform_script_process.returncode > 0:
+                if six.PY3:
+                    transform_script_stderrdata = transform_script_stderrdata.decode()
                 raise AirflowException("Transform script failed "
                                 "" + transform_script_stderrdata)
             else:
diff --git a/tests/operators/test_s3_file_transform_operator.py b/tests/operators/test_s3_file_transform_operator.py
new file mode 100644
index 0000000000..0bd6bdf230
--- /dev/null
+++ b/tests/operators/test_s3_file_transform_operator.py
@@ -0,0 +1,83 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import unittest
+
+import boto
+import six
+from airflow.exceptions import AirflowException
+from airflow.models import Connection
+from airflow.operators.s3_file_transform_operator import S3FileTransformOperator
+from airflow.utils import db
+try:
+    from moto import mock_s3_deprecated
+except ImportError:
+    mock_s3_deprecated = None
+
+
+DEFAULT_CONN_ID = "s3_default"
+
+
+class S3FileTransformTest(unittest.TestCase):
+    """
+    Tests for the S3 file transform operator.
+    """
+
+    @db.provide_session
+    def setUp(self, session=None):
+        self.mock_s3 = mock_s3_deprecated()
+        self.mock_s3.start()
+        self.s3_connection = session.query(Connection).filter(
+            Connection.conn_id == DEFAULT_CONN_ID
+        ).first()
+        if self.s3_connection is None:
+            self.s3_connection = Connection(conn_id=DEFAULT_CONN_ID, conn_type="s3")
+            session.add(self.s3_connection)
+            session.commit()
+
+    def tearDown(self):
+        self.mock_s3.stop()
+
+    @unittest.skipIf(mock_s3_deprecated is None, 'moto package not present')
+    def test_execute(self):
+        source_key = "/source/key"
+        source_bucket_name = "source-bucket"
+        dest_key = "/dest/key"
+        dest_bucket_name = "dest-bucket"
+        key_data = u"foobar"
+        # set up mock data
+        s3_client = boto.connect_s3()
+        source_bucket = s3_client.create_bucket(source_bucket_name)
+        dest_bucket = s3_client.create_bucket(dest_bucket_name)
+        source_obj = boto.s3.key.Key(source_bucket)
+        source_obj.key = source_key
+        source_obj.set_contents_from_string(key_data)
+        # Invoke .execute
+        s3_xform_task = S3FileTransformOperator(
+            task_id="s3_file_xform",
+            source_s3_key="s3://{}{}".format(source_bucket_name, source_key),
+            dest_s3_key="s3://{}{}".format(dest_bucket_name, dest_key),
+            transform_script="cp")
+        s3_xform_task.execute(None)
+        # ensure the data is correct
+        result = dest_bucket.get_key(dest_key)
+        stored = result.get_contents_as_string()
+        if six.PY3:
+            stored = stored.decode()
+        self.assertEqual(stored, key_data)
+
+
+if __name__ == "__main__":
+    unittest.main()
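
For context, here is a minimal sketch (not part of the patch; plain
Python for illustration) of the two Python 3 behaviours the change
addresses: pipes from subprocess yield bytes rather than str, and
bytes can only be written to a file opened in binary mode.

import subprocess
from tempfile import NamedTemporaryFile

# Under Python 3, communicate() returns bytes; concatenating them with
# a str (as the operator's logging call does) raises TypeError unless
# the output is decoded first, hence the six.PY3-guarded .decode().
proc = subprocess.Popen(["echo", "hello"],
                        stdout=subprocess.PIPE, stderr=subprocess.PIPE)
stdout, stderr = proc.communicate()
print("Transform script stdout " + stdout.decode())

# Likewise, boto's get_contents_to_file() writes bytes, so the
# temporary files must be opened as "wb" rather than "w".
with NamedTemporaryFile("wb") as f:
    f.write(b"binary payload")  # f.write(u"text") would raise TypeError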

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services